From f23612c3a3e5624051b8653389ca1a192c5b4875 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Wed, 20 Dec 2017 15:09:37 +0100 Subject: [PATCH] Add ResolvableType to HttpEntity for multipart Publishers This commit adds a ResolvableType field to HttpEntity, in order to support Publishers as multipart data. Without the type, the MultipartHttpMessageWriter does not know which delegate writer to use to write the part. Issue: SPR-16307 --- .../org/springframework/http/HttpEntity.java | 66 ++++++++++++ .../http/client/MultipartBodyBuilder.java | 102 +++++++++++++++++- .../multipart/MultipartHttpMessageWriter.java | 20 +++- .../client/MultipartBodyBuilderTests.java | 21 +++- .../MultipartHttpMessageWriterTests.java | 27 +++-- 5 files changed, 216 insertions(+), 20 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/http/HttpEntity.java b/spring-web/src/main/java/org/springframework/http/HttpEntity.java index 38c1d3fddd..a709bf7860 100644 --- a/spring-web/src/main/java/org/springframework/http/HttpEntity.java +++ b/spring-web/src/main/java/org/springframework/http/HttpEntity.java @@ -16,7 +16,12 @@ package org.springframework.http; +import org.reactivestreams.Publisher; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; import org.springframework.util.ObjectUtils; @@ -67,6 +72,9 @@ public class HttpEntity { @Nullable private final T body; + @Nullable + private final ResolvableType bodyType; + /** * Create a new, empty {@code HttpEntity}. @@ -97,7 +105,18 @@ public class HttpEntity { * @param headers the entity headers */ public HttpEntity(@Nullable T body, @Nullable MultiValueMap headers) { + this(body, null, headers); + } + + private HttpEntity(@Nullable T body, @Nullable ResolvableType bodyType, + @Nullable MultiValueMap headers) { this.body = body; + + if (bodyType == null && body != null) { + bodyType = ResolvableType.forClass(body.getClass()); + } + this.bodyType = bodyType ; + HttpHeaders tempHeaders = new HttpHeaders(); if (headers != null) { tempHeaders.putAll(headers); @@ -128,6 +147,13 @@ public class HttpEntity { return (this.body != null); } + /** + * Returns the type of the body. + */ + @Nullable + public ResolvableType getBodyType() { + return this.bodyType; + } @Override public boolean equals(@Nullable Object other) { @@ -159,4 +185,44 @@ public class HttpEntity { return builder.toString(); } + + // Static builder methods + + /** + * Create a new {@code HttpEntity} with the given {@link Publisher} as body, class contained in + * {@code publisher}, and headers. + * @param publisher the publisher to use as body + * @param elementClass the class of elements contained in the publisher + * @param headers the entity headers + * @param the type of the elements contained in the publisher + * @param

the type of the {@code Publisher} + * @return the created entity + */ + public static > HttpEntity

fromPublisher(P publisher, + Class elementClass, @Nullable MultiValueMap headers) { + + Assert.notNull(publisher, "'publisher' must not be null"); + Assert.notNull(elementClass, "'elementClass' must not be null"); + return new HttpEntity<>(publisher, ResolvableType.forClass(elementClass), headers); + } + + /** + * Create a new {@code HttpEntity} with the given {@link Publisher} as body, type contained in + * {@code publisher}, and headers. + * @param publisher the publisher to use as body + * @param typeReference the type of elements contained in the publisher + * @param headers the entity headers + * @param the type of the elements contained in the publisher + * @param

the type of the {@code Publisher} + * @return the created entity + */ + public static > HttpEntity

fromPublisher(P publisher, + ParameterizedTypeReference typeReference, + @Nullable MultiValueMap headers) { + + Assert.notNull(publisher, "'publisher' must not be null"); + Assert.notNull(typeReference, "'typeReference' must not be null"); + return new HttpEntity<>(publisher, ResolvableType.forType(typeReference), headers); + } + } diff --git a/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java b/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java index 593020037c..231a7a52e2 100644 --- a/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java +++ b/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java @@ -20,6 +20,10 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import org.reactivestreams.Publisher; + +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; @@ -96,6 +100,11 @@ public final class MultipartBodyBuilder { Assert.hasLength(name, "'name' must not be empty"); Assert.notNull(part, "'part' must not be null"); + if (part instanceof Publisher) { + throw new IllegalArgumentException("Use publisher(String, Publisher, Class) or " + + "publisher(String, Publisher, ParameterizedTypeReference) for adding Publisher parts"); + } + Object partBody; HttpHeaders partHeaders = new HttpHeaders(); @@ -116,6 +125,54 @@ public final class MultipartBodyBuilder { return builder; } + /** + * Adds a {@link Publisher} part to this builder, allowing for further header customization with + * the returned {@link PartBuilder}. + * @param name the name of the part to add (may not be empty) + * @param publisher the contents of the part to add + * @param elementClass the class of elements contained in the publisher + * @return a builder that allows for further header customization + */ + public > PartBuilder asyncPart(String name, P publisher, + Class elementClass) { + + Assert.notNull(elementClass, "'elementClass' must not be null"); + ResolvableType elementType = ResolvableType.forClass(elementClass); + Assert.hasLength(name, "'name' must not be empty"); + Assert.notNull(publisher, "'publisher' must not be null"); + Assert.notNull(elementType, "'elementType' must not be null"); + + HttpHeaders partHeaders = new HttpHeaders(); + PublisherClassPartBuilder builder = + new PublisherClassPartBuilder<>(publisher, elementClass, partHeaders); + this.parts.add(name, builder); + return builder; + + } + + /** + * Adds a {@link Publisher} part to this builder, allowing for further header customization with + * the returned {@link PartBuilder}. + * @param name the name of the part to add (may not be empty) + * @param publisher the contents of the part to add + * @param elementType the type of elements contained in the publisher + * @return a builder that allows for further header customization + */ + public > PartBuilder asyncPart(String name, P publisher, + ParameterizedTypeReference elementType) { + + Assert.notNull(elementType, "'elementType' must not be null"); + ResolvableType elementType1 = ResolvableType.forType(elementType); + Assert.hasLength(name, "'name' must not be empty"); + Assert.notNull(publisher, "'publisher' must not be null"); + Assert.notNull(elementType1, "'elementType' must not be null"); + + HttpHeaders partHeaders = new HttpHeaders(); + PublisherTypReferencePartBuilder builder = + new PublisherTypReferencePartBuilder<>(publisher, elementType, partHeaders); + this.parts.add(name, builder); + return builder; + } /** * Builder interface that allows for customization of part headers. @@ -136,10 +193,9 @@ public final class MultipartBodyBuilder { private static class DefaultPartBuilder implements PartBuilder { @Nullable - private final Object body; - - private final HttpHeaders headers; + protected final Object body; + protected final HttpHeaders headers; public DefaultPartBuilder(@Nullable Object body, HttpHeaders headers) { this.body = body; @@ -157,4 +213,44 @@ public final class MultipartBodyBuilder { } } + private static class PublisherClassPartBuilder> + extends DefaultPartBuilder { + + private final Class bodyType; + + public PublisherClassPartBuilder(P body, Class bodyType, HttpHeaders headers) { + super(body, headers); + this.bodyType = bodyType; + } + + @Override + @SuppressWarnings("unchecked") + public HttpEntity build() { + P body = (P) this.body; + Assert.state(body != null, "'body' must not be null"); + return HttpEntity.fromPublisher(body, this.bodyType, this.headers); + } + } + + private static class PublisherTypReferencePartBuilder> + extends DefaultPartBuilder { + + private final ParameterizedTypeReference bodyType; + + public PublisherTypReferencePartBuilder(P body, ParameterizedTypeReference bodyType, + HttpHeaders headers) { + + super(body, headers); + this.bodyType = bodyType; + } + + @Override + @SuppressWarnings("unchecked") + public HttpEntity build() { + P body = (P) this.body; + Assert.state(body != null, "'body' must not be null"); + return HttpEntity.fromPublisher(body, this.bodyType, this.headers); + } + } + } diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java index 03008459a6..5923ac7eb0 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java @@ -230,31 +230,41 @@ public class MultipartHttpMessageWriter implements HttpMessageWriter) value).getHeaders()); - body = ((HttpEntity) value).getBody(); + HttpEntity httpEntity = (HttpEntity) value; + outputMessage.getHeaders().putAll(httpEntity.getHeaders()); + body = httpEntity.getBody(); Assert.state(body != null, "MultipartHttpMessageWriter only supports HttpEntity with body"); + bodyType = httpEntity.getBodyType(); } else { body = value; } + if (bodyType == null) { + bodyType = ResolvableType.forClass(body.getClass()); + } + String filename = (body instanceof Resource ? ((Resource) body).getFilename() : null); outputMessage.getHeaders().setContentDispositionFormData(name, filename); - ResolvableType bodyType = ResolvableType.forClass(body.getClass()); MediaType contentType = outputMessage.getHeaders().getContentType(); + final ResolvableType finalBodyType = bodyType; Optional> writer = this.partWriters.stream() - .filter(partWriter -> partWriter.canWrite(bodyType, contentType)) + .filter(partWriter -> partWriter.canWrite(finalBodyType, contentType)) .findFirst(); if (!writer.isPresent()) { return Flux.error(new CodecException("No suitable writer found for part: " + name)); } + Publisher bodyPublisher = + body instanceof Publisher ? (Publisher) body : Mono.just(body); + Mono partWritten = ((HttpMessageWriter) writer.get()) - .write(Mono.just(body), bodyType, contentType, outputMessage, Collections.emptyMap()); + .write(bodyPublisher, bodyType, contentType, outputMessage, Collections.emptyMap()); // partWritten.subscribe() is required in order to make sure MultipartHttpOutputMessage#getBody() // returns a non-null value (occurs with ResourceHttpMessageWriter that invokes diff --git a/spring-web/src/test/java/org/springframework/http/client/MultipartBodyBuilderTests.java b/spring-web/src/test/java/org/springframework/http/client/MultipartBodyBuilderTests.java index 9293d02905..ed42d016c6 100644 --- a/spring-web/src/test/java/org/springframework/http/client/MultipartBodyBuilderTests.java +++ b/spring-web/src/test/java/org/springframework/http/client/MultipartBodyBuilderTests.java @@ -17,7 +17,10 @@ package org.springframework.http.client; import org.junit.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import org.springframework.core.ResolvableType; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.http.HttpEntity; @@ -34,23 +37,25 @@ public class MultipartBodyBuilderTests { @Test public void builder() throws Exception { - MultiValueMap form = new LinkedMultiValueMap<>(); - form.add("form field", "form value"); + MultiValueMap multipartData = new LinkedMultiValueMap<>(); + multipartData.add("form field", "form value"); Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg"); HttpHeaders entityHeaders = new HttpHeaders(); entityHeaders.add("foo", "bar"); HttpEntity entity = new HttpEntity<>("body", entityHeaders); + Publisher publisher = Flux.just("foo", "bar", "baz"); MultipartBodyBuilder builder = new MultipartBodyBuilder(); - builder.part("key", form).header("foo", "bar"); + builder.part("key", multipartData).header("foo", "bar"); builder.part("logo", logo).header("baz", "qux"); builder.part("entity", entity).header("baz", "qux"); + builder.asyncPart("publisher", publisher, String.class).header("baz", "qux"); MultiValueMap> result = builder.build(); - assertEquals(3, result.size()); + assertEquals(4, result.size()); assertNotNull(result.getFirst("key")); - assertEquals(form, result.getFirst("key").getBody()); + assertEquals(multipartData, result.getFirst("key").getBody()); assertEquals("bar", result.getFirst("key").getHeaders().getFirst("foo")); assertNotNull(result.getFirst("logo")); @@ -61,6 +66,12 @@ public class MultipartBodyBuilderTests { assertEquals("body", result.getFirst("entity").getBody()); assertEquals("bar", result.getFirst("entity").getHeaders().getFirst("foo")); assertEquals("qux", result.getFirst("entity").getHeaders().getFirst("baz")); + + assertNotNull(result.getFirst("publisher")); + assertEquals(publisher, result.getFirst("publisher").getBody()); + assertEquals(ResolvableType.forClass(String.class), result.getFirst("publisher").getBodyType()); + assertEquals("bar", result.getFirst("entity").getHeaders().getFirst("foo")); + assertEquals("qux", result.getFirst("entity").getHeaders().getFirst("baz")); } diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java index 6f0bf30939..1d42d34d6f 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java @@ -22,6 +22,8 @@ import java.util.List; import java.util.Map; import org.junit.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.core.ResolvableType; @@ -36,10 +38,7 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; import org.springframework.util.MultiValueMap; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * @author Sebastien Deleuze @@ -80,6 +79,8 @@ public class MultipartHttpMessageWriterTests { } }; + Publisher publisher = Flux.just("foo", "bar", "baz"); + MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); bodyBuilder.part("name 1", "value 1"); bodyBuilder.part("name 2", "value 2+1"); @@ -87,14 +88,15 @@ public class MultipartHttpMessageWriterTests { bodyBuilder.part("logo", logo); bodyBuilder.part("utf8", utf8); bodyBuilder.part("json", new Foo("bar"), MediaType.APPLICATION_JSON_UTF8); - Mono>> publisher = Mono.just(bodyBuilder.build()); + bodyBuilder.asyncPart("publisher", publisher, String.class); + Mono>> result = Mono.just(bodyBuilder.build()); MockServerHttpResponse response = new MockServerHttpResponse(); Map hints = Collections.emptyMap(); - this.writer.write(publisher, null, MediaType.MULTIPART_FORM_DATA, response, hints).block(Duration.ofSeconds(5)); + this.writer.write(result, null, MediaType.MULTIPART_FORM_DATA, response, hints).block(Duration.ofSeconds(5)); MultiValueMap requestParts = parse(response, hints); - assertEquals(5, requestParts.size()); + assertEquals(6, requestParts.size()); Part part = requestParts.getFirst("name 1"); assertTrue(part instanceof FormFieldPart); @@ -136,6 +138,17 @@ public class MultipartHttpMessageWriterTests { assertEquals("{\"bar\":\"bar\"}", value); + part = requestParts.getFirst("publisher"); + assertEquals("publisher", part.name()); + + value = StringDecoder.textPlainOnly(false).decodeToMono(part.content(), + ResolvableType.forClass(String.class), MediaType.TEXT_PLAIN, + Collections.emptyMap()).block(Duration.ZERO); + + assertEquals("foobarbaz", value); + + + } private MultiValueMap parse(MockServerHttpResponse response, Map hints) {