diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java index 350d693cd3..602c8f1431 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java @@ -17,11 +17,9 @@ package org.springframework.http.codec.multipart; import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -49,11 +47,9 @@ import org.springframework.http.ReactiveHttpOutputMessage; import org.springframework.http.codec.EncoderHttpMessageWriter; import org.springframework.http.codec.FormHttpMessageWriter; import org.springframework.http.codec.HttpMessageWriter; -import org.springframework.http.codec.LoggingCodecSupport; import org.springframework.http.codec.ResourceHttpMessageWriter; import org.springframework.lang.Nullable; import org.springframework.util.Assert; -import org.springframework.util.MimeTypeUtils; import org.springframework.util.MultiValueMap; /** @@ -77,14 +73,9 @@ import org.springframework.util.MultiValueMap; * @since 5.0 * @see FormHttpMessageWriter */ -public class MultipartHttpMessageWriter extends LoggingCodecSupport +public class MultipartHttpMessageWriter extends MultipartWriterSupport implements HttpMessageWriter> { - /** - * THe default charset used by the writer. - */ - public static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; - /** Suppress logging from individual part writers (full map logged at this level). */ private static final Map DEFAULT_HINTS = Hints.from(Hints.SUPPRESS_LOGGING_HINT, true); @@ -94,10 +85,6 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport @Nullable private final HttpMessageWriter> formWriter; - private Charset charset = DEFAULT_CHARSET; - - private final List supportedMediaTypes; - /** * Constructor with a default list of part writers (String and Resource). @@ -126,9 +113,9 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport public MultipartHttpMessageWriter(List> partWriters, @Nullable HttpMessageWriter> formWriter) { + super(initMediaTypes(formWriter)); this.partWriters = partWriters; this.formWriter = formWriter; - this.supportedMediaTypes = initMediaTypes(formWriter); } private static List initMediaTypes(@Nullable HttpMessageWriter formWriter) { @@ -168,25 +155,6 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport this.charset = charset; } - /** - * Return the configured charset for part headers. - */ - public Charset getCharset() { - return this.charset; - } - - - @Override - public List getWritableMediaTypes() { - return this.supportedMediaTypes; - } - - @Override - public boolean canWrite(ResolvableType elementType, @Nullable MediaType mediaType) { - return (MultiValueMap.class.isAssignableFrom(elementType.toClass()) && - (mediaType == null || - this.supportedMediaTypes.stream().anyMatch(element -> element.isCompatibleWith(mediaType)))); - } @Override public Mono write(Publisher> inputStream, @@ -225,16 +193,7 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport byte[] boundary = generateMultipartBoundary(); - Map params = new HashMap<>(); - if (mediaType != null) { - params.putAll(mediaType.getParameters()); - } - params.put("boundary", new String(boundary, StandardCharsets.US_ASCII)); - params.put("charset", getCharset().name()); - - mediaType = (mediaType != null ? mediaType : MediaType.MULTIPART_FORM_DATA); - mediaType = new MediaType(mediaType, params); - + mediaType = getMultipartMediaType(mediaType, boundary); outputMessage.getHeaders().setContentType(mediaType); LogFormatUtils.traceDebug(logger, traceOn -> Hints.getLogPrefix(hints) + "Encoding " + @@ -252,14 +211,6 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport return outputMessage.writeWith(body); } - /** - * Generate a multipart boundary. - *

By default delegates to {@link MimeTypeUtils#generateMultipartBoundary()}. - */ - protected byte[] generateMultipartBoundary() { - return MimeTypeUtils.generateMultipartBoundary(); - } - private Flux encodePartValues( byte[] boundary, String name, List values, DataBufferFactory bufferFactory) { @@ -268,15 +219,15 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport } @SuppressWarnings("unchecked") - private Flux encodePart(byte[] boundary, String name, T value, DataBufferFactory bufferFactory) { - MultipartHttpOutputMessage outputMessage = new MultipartHttpOutputMessage(bufferFactory, getCharset()); - HttpHeaders outputHeaders = outputMessage.getHeaders(); + private Flux encodePart(byte[] boundary, String name, T value, DataBufferFactory factory) { + MultipartHttpOutputMessage message = new MultipartHttpOutputMessage(factory); + HttpHeaders headers = message.getHeaders(); T body; ResolvableType resolvableType = null; if (value instanceof HttpEntity) { HttpEntity httpEntity = (HttpEntity) value; - outputHeaders.putAll(httpEntity.getHeaders()); + headers.putAll(httpEntity.getHeaders()); body = httpEntity.getBody(); Assert.state(body != null, "MultipartHttpMessageWriter only supports HttpEntity with body"); if (httpEntity instanceof ResolvableTypeProvider) { @@ -290,20 +241,20 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport resolvableType = ResolvableType.forClass(body.getClass()); } - if (!outputHeaders.containsKey(HttpHeaders.CONTENT_DISPOSITION)) { + if (!headers.containsKey(HttpHeaders.CONTENT_DISPOSITION)) { if (body instanceof Resource) { - outputHeaders.setContentDispositionFormData(name, ((Resource) body).getFilename()); + headers.setContentDispositionFormData(name, ((Resource) body).getFilename()); } else if (resolvableType.resolve() == Resource.class) { - body = (T) Mono.from((Publisher) body).doOnNext(o -> outputHeaders + body = (T) Mono.from((Publisher) body).doOnNext(o -> headers .setContentDispositionFormData(name, ((Resource) o).getFilename())); } else { - outputHeaders.setContentDispositionFormData(name, null); + headers.setContentDispositionFormData(name, null); } } - MediaType contentType = outputHeaders.getContentType(); + MediaType contentType = headers.getContentType(); final ResolvableType finalBodyType = resolvableType; Optional> writer = this.partWriters.stream() @@ -321,62 +272,24 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport // but only stores the body Flux and returns Mono.empty(). Mono partContentReady = ((HttpMessageWriter) writer.get()) - .write(bodyPublisher, resolvableType, contentType, outputMessage, DEFAULT_HINTS); + .write(bodyPublisher, resolvableType, contentType, message, DEFAULT_HINTS); // After partContentReady, we can access the part content from MultipartHttpOutputMessage // and use it for writing to the actual request body - Flux partContent = partContentReady.thenMany(Flux.defer(outputMessage::getBody)); + Flux partContent = partContentReady.thenMany(Flux.defer(message::getBody)); return Flux.concat( - generateBoundaryLine(boundary, bufferFactory), + generateBoundaryLine(boundary, factory), partContent, - generateNewLine(bufferFactory)); + generateNewLine(factory)); } - private Mono generateBoundaryLine(byte[] boundary, DataBufferFactory bufferFactory) { - return Mono.fromCallable(() -> { - DataBuffer buffer = bufferFactory.allocateBuffer(boundary.length + 4); - buffer.write((byte)'-'); - buffer.write((byte)'-'); - buffer.write(boundary); - buffer.write((byte)'\r'); - buffer.write((byte)'\n'); - return buffer; - }); - } - - private Mono generateNewLine(DataBufferFactory bufferFactory) { - return Mono.fromCallable(() -> { - DataBuffer buffer = bufferFactory.allocateBuffer(2); - buffer.write((byte)'\r'); - buffer.write((byte)'\n'); - return buffer; - }); - } - - private Mono generateLastLine(byte[] boundary, DataBufferFactory bufferFactory) { - return Mono.fromCallable(() -> { - DataBuffer buffer = bufferFactory.allocateBuffer(boundary.length + 6); - buffer.write((byte)'-'); - buffer.write((byte)'-'); - buffer.write(boundary); - buffer.write((byte)'-'); - buffer.write((byte)'-'); - buffer.write((byte)'\r'); - buffer.write((byte)'\n'); - return buffer; - }); - } - - - private static class MultipartHttpOutputMessage implements ReactiveHttpOutputMessage { + private class MultipartHttpOutputMessage implements ReactiveHttpOutputMessage { private final DataBufferFactory bufferFactory; - private final Charset charset; - private final HttpHeaders headers = new HttpHeaders(); private final AtomicBoolean committed = new AtomicBoolean(); @@ -384,9 +297,8 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport @Nullable private Flux body; - public MultipartHttpOutputMessage(DataBufferFactory bufferFactory, Charset charset) { + public MultipartHttpOutputMessage(DataBufferFactory bufferFactory) { this.bufferFactory = bufferFactory; - this.charset = charset; } @Override @@ -414,33 +326,12 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport if (this.body != null) { return Mono.error(new IllegalStateException("Multiple calls to writeWith() not supported")); } - this.body = generateHeaders().concatWith(body); + this.body = generatePartHeaders(this.headers, this.bufferFactory).concatWith(body); // We don't actually want to write (just save the body Flux) return Mono.empty(); } - private Mono generateHeaders() { - return Mono.fromCallable(() -> { - DataBuffer buffer = this.bufferFactory.allocateBuffer(); - for (Map.Entry> entry : this.headers.entrySet()) { - byte[] headerName = entry.getKey().getBytes(this.charset); - for (String headerValueString : entry.getValue()) { - byte[] headerValue = headerValueString.getBytes(this.charset); - buffer.write(headerName); - buffer.write((byte)':'); - buffer.write((byte)' '); - buffer.write(headerValue); - buffer.write((byte)'\r'); - buffer.write((byte)'\n'); - } - } - buffer.write((byte)'\r'); - buffer.write((byte)'\n'); - return buffer; - }); - } - @Override public Mono writeAndFlushWith(Publisher> body) { return Mono.error(new UnsupportedOperationException()); diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartWriterSupport.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartWriterSupport.java new file mode 100644 index 0000000000..d7cf518121 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartWriterSupport.java @@ -0,0 +1,168 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.codec.multipart; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.codec.LoggingCodecSupport; +import org.springframework.lang.Nullable; +import org.springframework.util.MimeTypeUtils; +import org.springframework.util.MultiValueMap; + +/** + * Support class for multipart HTTP message writers. + * + * @author Rossen Stoyanchev + * @since 5.3 + */ +public class MultipartWriterSupport extends LoggingCodecSupport { + + /** THe default charset used by the writer. */ + public static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + + protected final List supportedMediaTypes; + + protected Charset charset = DEFAULT_CHARSET; + + + /** + * Constructor with the list of supported media types. + */ + protected MultipartWriterSupport(List supportedMediaTypes) { + this.supportedMediaTypes = supportedMediaTypes; + } + + + /** + * Return the configured charset for part headers. + */ + public Charset getCharset() { + return this.charset; + } + + public List getWritableMediaTypes() { + return this.supportedMediaTypes; + } + + + public boolean canWrite(ResolvableType elementType, @Nullable MediaType mediaType) { + if (MultiValueMap.class.isAssignableFrom(elementType.toClass())) { + if (mediaType == null) { + return true; + } + for (MediaType supportedMediaType : this.supportedMediaTypes) { + if (supportedMediaType.isCompatibleWith(mediaType)) { + return true; + } + } + } + return false; + } + + /** + * Generate a multipart boundary. + *

By default delegates to {@link MimeTypeUtils#generateMultipartBoundary()}. + */ + protected byte[] generateMultipartBoundary() { + return MimeTypeUtils.generateMultipartBoundary(); + } + + /** + * Prepare the {@code MediaType} to use by adding "boundary" and "charset" + * parameters to the given {@code mediaType} or "mulitpart/form-data" + * otherwise by default. + */ + protected MediaType getMultipartMediaType(@Nullable MediaType mediaType, byte[] boundary) { + Map params = new HashMap<>(); + if (mediaType != null) { + params.putAll(mediaType.getParameters()); + } + params.put("boundary", new String(boundary, StandardCharsets.US_ASCII)); + params.put("charset", getCharset().name()); + + mediaType = (mediaType != null ? mediaType : MediaType.MULTIPART_FORM_DATA); + mediaType = new MediaType(mediaType, params); + return mediaType; + } + + protected Mono generateBoundaryLine(byte[] boundary, DataBufferFactory bufferFactory) { + return Mono.fromCallable(() -> { + DataBuffer buffer = bufferFactory.allocateBuffer(boundary.length + 4); + buffer.write((byte)'-'); + buffer.write((byte)'-'); + buffer.write(boundary); + buffer.write((byte)'\r'); + buffer.write((byte)'\n'); + return buffer; + }); + } + + protected Mono generateNewLine(DataBufferFactory bufferFactory) { + return Mono.fromCallable(() -> { + DataBuffer buffer = bufferFactory.allocateBuffer(2); + buffer.write((byte)'\r'); + buffer.write((byte)'\n'); + return buffer; + }); + } + + protected Mono generateLastLine(byte[] boundary, DataBufferFactory bufferFactory) { + return Mono.fromCallable(() -> { + DataBuffer buffer = bufferFactory.allocateBuffer(boundary.length + 6); + buffer.write((byte)'-'); + buffer.write((byte)'-'); + buffer.write(boundary); + buffer.write((byte)'-'); + buffer.write((byte)'-'); + buffer.write((byte)'\r'); + buffer.write((byte)'\n'); + return buffer; + }); + } + + protected Mono generatePartHeaders(HttpHeaders headers, DataBufferFactory bufferFactory) { + return Mono.fromCallable(() -> { + DataBuffer buffer = bufferFactory.allocateBuffer(); + for (Map.Entry> entry : headers.entrySet()) { + byte[] headerName = entry.getKey().getBytes(getCharset()); + for (String headerValueString : entry.getValue()) { + byte[] headerValue = headerValueString.getBytes(getCharset()); + buffer.write(headerName); + buffer.write((byte)':'); + buffer.write((byte)' '); + buffer.write(headerValue); + buffer.write((byte)'\r'); + buffer.write((byte)'\n'); + } + } + buffer.write((byte)'\r'); + buffer.write((byte)'\n'); + return buffer; + }); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/PartHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/PartHttpMessageWriter.java new file mode 100644 index 0000000000..d1d9f49916 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/PartHttpMessageWriter.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.codec.multipart; + +import java.util.Map; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.Hints; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.PooledDataBuffer; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpOutputMessage; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.lang.Nullable; + +/** + * {@link HttpMessageWriter} for writing with {@link Part}. This can be useful + * on the server side to write a {@code Flux} received from a client to + * some remote service. + * + * @author Rossen Stoyanchev + * @since 5.3 + */ +public class PartHttpMessageWriter extends MultipartWriterSupport implements HttpMessageWriter { + + + public PartHttpMessageWriter() { + super(MultipartHttpMessageReader.MIME_TYPES); + } + + + @Override + public Mono write(Publisher parts, + ResolvableType elementType, @Nullable MediaType mediaType, ReactiveHttpOutputMessage outputMessage, + Map hints) { + + byte[] boundary = generateMultipartBoundary(); + + mediaType = getMultipartMediaType(mediaType, boundary); + outputMessage.getHeaders().setContentType(mediaType); + + if (logger.isDebugEnabled()) { + logger.debug(Hints.getLogPrefix(hints) + "Encoding Publisher"); + } + + Flux body = Flux.from(parts) + .concatMap(part -> encodePart(boundary, part, outputMessage.bufferFactory())) + .concatWith(generateLastLine(boundary, outputMessage.bufferFactory())) + .doOnDiscard(PooledDataBuffer.class, PooledDataBuffer::release); + + return outputMessage.writeWith(body); + } + + private Flux encodePart(byte[] boundary, Part part, DataBufferFactory bufferFactory) { + HttpHeaders headers = new HttpHeaders(part.headers()); + + String name = part.name(); + if (!headers.containsKey(HttpHeaders.CONTENT_DISPOSITION)) { + headers.setContentDispositionFormData(name, + (part instanceof FilePart ? ((FilePart) part).filename() : null)); + } + + return Flux.concat( + generateBoundaryLine(boundary, bufferFactory), + generatePartHeaders(headers, bufferFactory), + part.content(), + generateNewLine(bufferFactory)); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/support/ServerDefaultCodecsImpl.java b/spring-web/src/main/java/org/springframework/http/codec/support/ServerDefaultCodecsImpl.java index eaab4e3237..a50b61e10e 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/support/ServerDefaultCodecsImpl.java +++ b/spring-web/src/main/java/org/springframework/http/codec/support/ServerDefaultCodecsImpl.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import org.springframework.http.codec.HttpMessageWriter; import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.http.codec.ServerSentEventHttpMessageWriter; import org.springframework.http.codec.multipart.MultipartHttpMessageReader; +import org.springframework.http.codec.multipart.PartHttpMessageWriter; import org.springframework.http.codec.multipart.SynchronossPartHttpMessageReader; import org.springframework.lang.Nullable; @@ -74,6 +75,11 @@ class ServerDefaultCodecsImpl extends BaseDefaultCodecs implements ServerCodecCo } } + @Override + protected void extendTypedWriters(List> typedWriters) { + addCodec(typedWriters, new PartHttpMessageWriter()); + } + @Override protected void extendObjectWriters(List> objectWriters) { objectWriters.add(new ServerSentEventHttpMessageWriter(getSseEncoder())); diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java index 7a282bfdd4..86036bf60a 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java @@ -49,6 +49,8 @@ import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; /** + * Unit tests for {@link MultipartHttpMessageWriter}. + * * @author Sebastien Deleuze * @author Rossen Stoyanchev */ @@ -118,39 +120,34 @@ public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTests { this.writer.write(result, null, MediaType.MULTIPART_FORM_DATA, this.response, hints) .block(Duration.ofSeconds(5)); - MultiValueMap requestParts = parse(hints); + MultiValueMap requestParts = parse(this.response, hints); assertThat(requestParts.size()).isEqualTo(7); Part part = requestParts.getFirst("name 1"); - boolean condition4 = part instanceof FormFieldPart; - assertThat(condition4).isTrue(); + assertThat(part instanceof FormFieldPart).isTrue(); assertThat(part.name()).isEqualTo("name 1"); assertThat(((FormFieldPart) part).value()).isEqualTo("value 1"); List parts2 = requestParts.get("name 2"); assertThat(parts2.size()).isEqualTo(2); part = parts2.get(0); - boolean condition3 = part instanceof FormFieldPart; - assertThat(condition3).isTrue(); + assertThat(part instanceof FormFieldPart).isTrue(); assertThat(part.name()).isEqualTo("name 2"); assertThat(((FormFieldPart) part).value()).isEqualTo("value 2+1"); part = parts2.get(1); - boolean condition2 = part instanceof FormFieldPart; - assertThat(condition2).isTrue(); + assertThat(part instanceof FormFieldPart).isTrue(); assertThat(part.name()).isEqualTo("name 2"); assertThat(((FormFieldPart) part).value()).isEqualTo("value 2+2"); part = requestParts.getFirst("logo"); - boolean condition1 = part instanceof FilePart; - assertThat(condition1).isTrue(); + assertThat(part instanceof FilePart).isTrue(); assertThat(part.name()).isEqualTo("logo"); assertThat(((FilePart) part).filename()).isEqualTo("logo.jpg"); assertThat(part.headers().getContentType()).isEqualTo(MediaType.IMAGE_JPEG); assertThat(part.headers().getContentLength()).isEqualTo(logo.getFile().length()); part = requestParts.getFirst("utf8"); - boolean condition = part instanceof FilePart; - assertThat(condition).isTrue(); + assertThat(part instanceof FilePart).isTrue(); assertThat(part.name()).isEqualTo("utf8"); assertThat(((FilePart) part).filename()).isEqualTo("Hall\u00F6le.jpg"); assertThat(part.headers().getContentType()).isEqualTo(MediaType.IMAGE_JPEG); @@ -195,7 +192,7 @@ public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTests { assertThat(contentType.getParameter("boundary")).isNotEmpty(); assertThat(contentType.getParameter("charset")).isEqualTo("UTF-8"); - MultiValueMap requestParts = parse(hints); + MultiValueMap requestParts = parse(this.response, hints); assertThat(requestParts.size()).isEqualTo(2); assertThat(requestParts.getFirst("name 1").name()).isEqualTo("name 1"); assertThat(requestParts.getFirst("name 2").name()).isEqualTo("name 2"); @@ -222,13 +219,12 @@ public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTests { Map hints = Collections.emptyMap(); this.writer.write(result, null, MediaType.MULTIPART_FORM_DATA, this.response, hints).block(); - MultiValueMap requestParts = parse(hints); + MultiValueMap requestParts = parse(this.response, hints); assertThat(requestParts.size()).isEqualTo(1); Part part = requestParts.getFirst("logo"); assertThat(part.name()).isEqualTo("logo"); - boolean condition = part instanceof FilePart; - assertThat(condition).isTrue(); + assertThat(part instanceof FilePart).isTrue(); assertThat(((FilePart) part).filename()).isEqualTo("logo.jpg"); assertThat(part.headers().getContentType()).isEqualTo(MediaType.IMAGE_JPEG); assertThat(part.headers().getContentLength()).isEqualTo(logo.getFile().length()); @@ -273,24 +269,22 @@ public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTests { this.writer.write(Mono.just(multipartData), null, MediaType.MULTIPART_FORM_DATA, this.response, hints).block(); - MultiValueMap requestParts = parse(hints); + MultiValueMap requestParts = parse(this.response, hints); assertThat(requestParts.size()).isEqualTo(2); Part part = requestParts.getFirst("resource"); - boolean condition1 = part instanceof FilePart; - assertThat(condition1).isTrue(); + assertThat(part instanceof FilePart).isTrue(); assertThat(((FilePart) part).filename()).isEqualTo("spring.jpg"); assertThat(part.headers().getContentLength()).isEqualTo(logo.getFile().length()); part = requestParts.getFirst("buffers"); - boolean condition = part instanceof FilePart; - assertThat(condition).isTrue(); + assertThat(part instanceof FilePart).isTrue(); assertThat(((FilePart) part).filename()).isEqualTo("buffers.jpg"); assertThat(part.headers().getContentLength()).isEqualTo(logo.getFile().length()); } - private MultiValueMap parse(Map hints) { - MediaType contentType = this.response.getHeaders().getContentType(); + static MultiValueMap parse(MockServerHttpResponse response, Map hints) { + MediaType contentType = response.getHeaders().getContentType(); assertThat(contentType.getParameter("boundary")).as("No boundary found").isNotNull(); // see if Synchronoss NIO Multipart can read what we wrote @@ -299,7 +293,7 @@ public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTests { MockServerHttpRequest request = MockServerHttpRequest.post("/") .contentType(MediaType.parseMediaType(contentType.toString())) - .body(this.response.getBody()); + .body(response.getBody()); ResolvableType elementType = ResolvableType.forClassWithGenerics( MultiValueMap.class, String.class, Part.class); diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/PartHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/PartHttpMessageWriterTests.java new file mode 100644 index 0000000000..c519b83d82 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/PartHttpMessageWriterTests.java @@ -0,0 +1,119 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.http.codec.multipart; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Collections; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; + +import org.springframework.core.ResolvableType; +import org.springframework.core.codec.StringDecoder; +import org.springframework.core.testfixture.io.buffer.AbstractLeakCheckingTests; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.util.MultiValueMap; +import org.springframework.web.testfixture.http.server.reactive.MockServerHttpResponse; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.springframework.http.codec.multipart.MultipartHttpMessageWriterTests.parse; + +/** + * Unit tests for {@link PartHttpMessageWriter}. + * + * @author Rossen Stoyanchev + * @since 5.3 + */ +public class PartHttpMessageWriterTests extends AbstractLeakCheckingTests { + + private final PartHttpMessageWriter writer = new PartHttpMessageWriter(); + + private final MockServerHttpResponse response = new MockServerHttpResponse(this.bufferFactory); + + + @Test + public void canWrite() { + assertThat(this.writer.canWrite( + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Object.class), + MediaType.MULTIPART_FORM_DATA)).isTrue(); + assertThat(this.writer.canWrite( + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class), + MediaType.MULTIPART_FORM_DATA)).isTrue(); + assertThat(this.writer.canWrite( + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Object.class), + MediaType.MULTIPART_MIXED)).isTrue(); + assertThat(this.writer.canWrite( + ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Object.class), + MediaType.MULTIPART_RELATED)).isTrue(); + + assertThat(this.writer.canWrite( + ResolvableType.forClassWithGenerics(Map.class, String.class, Object.class), + MediaType.MULTIPART_FORM_DATA)).isFalse(); + } + + @Test + void write() { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + Part textPart = mock(Part.class); + given(textPart.name()).willReturn("text part"); + given(textPart.headers()).willReturn(headers); + given(textPart.content()).willReturn(Flux.just( + this.bufferFactory.wrap("text1".getBytes(StandardCharsets.UTF_8)), + this.bufferFactory.wrap("text2".getBytes(StandardCharsets.UTF_8)))); + + FilePart filePart = mock(FilePart.class); + given(filePart.name()).willReturn("file part"); + given(filePart.headers()).willReturn(new HttpHeaders()); + given(filePart.filename()).willReturn("file.txt"); + given(filePart.content()).willReturn(Flux.just( + this.bufferFactory.wrap("Aa".getBytes(StandardCharsets.UTF_8)), + this.bufferFactory.wrap("Bb".getBytes(StandardCharsets.UTF_8)), + this.bufferFactory.wrap("Cc".getBytes(StandardCharsets.UTF_8)) + )); + + Map hints = Collections.emptyMap(); + this.writer.write(Flux.just(textPart, filePart), null, MediaType.MULTIPART_FORM_DATA, this.response, hints) + .block(Duration.ofSeconds(5)); + + MultiValueMap requestParts = parse(this.response, hints); + assertThat(requestParts.size()).isEqualTo(2); + + Part part = requestParts.getFirst("text part"); + assertThat(part.name()).isEqualTo("text part"); + assertThat(part.headers().getContentType()).isEqualTo(MediaType.TEXT_PLAIN); + String value = decodeToString(part); + assertThat(value).isEqualTo("text1text2"); + + part = requestParts.getFirst("file part"); + assertThat(part.name()).isEqualTo("file part"); + assertThat(((FilePart) part).filename()).isEqualTo("file.txt"); + assertThat(decodeToString(part)).isEqualTo("AaBbCc"); + } + + @SuppressWarnings("ConstantConditions") + private String decodeToString(Part part) { + return StringDecoder.textPlainOnly().decodeToMono(part.content(), + ResolvableType.forClass(String.class), MediaType.TEXT_PLAIN, + Collections.emptyMap()).block(Duration.ZERO); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/codec/support/ServerCodecConfigurerTests.java b/spring-web/src/test/java/org/springframework/http/codec/support/ServerCodecConfigurerTests.java index 3a3ab56515..8da665c7f6 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/support/ServerCodecConfigurerTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/support/ServerCodecConfigurerTests.java @@ -57,6 +57,7 @@ import org.springframework.http.codec.json.Jackson2JsonEncoder; import org.springframework.http.codec.json.Jackson2SmileDecoder; import org.springframework.http.codec.json.Jackson2SmileEncoder; import org.springframework.http.codec.multipart.MultipartHttpMessageReader; +import org.springframework.http.codec.multipart.PartHttpMessageWriter; import org.springframework.http.codec.multipart.SynchronossPartHttpMessageReader; import org.springframework.http.codec.protobuf.ProtobufDecoder; import org.springframework.http.codec.protobuf.ProtobufHttpMessageWriter; @@ -102,7 +103,7 @@ public class ServerCodecConfigurerTests { @Test public void defaultWriters() { List> writers = this.configurer.getWriters(); - assertThat(writers.size()).isEqualTo(12); + assertThat(writers.size()).isEqualTo(13); assertThat(getNextEncoder(writers).getClass()).isEqualTo(ByteArrayEncoder.class); assertThat(getNextEncoder(writers).getClass()).isEqualTo(ByteBufferEncoder.class); assertThat(getNextEncoder(writers).getClass()).isEqualTo(DataBufferEncoder.class); @@ -110,6 +111,7 @@ public class ServerCodecConfigurerTests { assertThat(writers.get(index.getAndIncrement()).getClass()).isEqualTo(ResourceHttpMessageWriter.class); assertStringEncoder(getNextEncoder(writers), true); assertThat(writers.get(index.getAndIncrement()).getClass()).isEqualTo(ProtobufHttpMessageWriter.class); + assertThat(writers.get(this.index.getAndIncrement()).getClass()).isEqualTo(PartHttpMessageWriter.class); assertThat(getNextEncoder(writers).getClass()).isEqualTo(Jackson2JsonEncoder.class); assertThat(getNextEncoder(writers).getClass()).isEqualTo(Jackson2SmileEncoder.class); assertThat(getNextEncoder(writers).getClass()).isEqualTo(Jaxb2XmlEncoder.class); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/config/WebFluxConfigurationSupportTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/config/WebFluxConfigurationSupportTests.java index 5ed8aa49cc..059b7779e7 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/config/WebFluxConfigurationSupportTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/config/WebFluxConfigurationSupportTests.java @@ -207,7 +207,7 @@ public class WebFluxConfigurationSupportTests { assertThat(handler.getOrder()).isEqualTo(0); List> writers = handler.getMessageWriters(); - assertThat(writers.size()).isEqualTo(12); + assertThat(writers.size()).isEqualTo(13); assertHasMessageWriter(writers, forClass(byte[].class), APPLICATION_OCTET_STREAM); assertHasMessageWriter(writers, forClass(ByteBuffer.class), APPLICATION_OCTET_STREAM); @@ -235,7 +235,7 @@ public class WebFluxConfigurationSupportTests { assertThat(handler.getOrder()).isEqualTo(100); List> writers = handler.getMessageWriters(); - assertThat(writers.size()).isEqualTo(12); + assertThat(writers.size()).isEqualTo(13); assertHasMessageWriter(writers, forClass(byte[].class), APPLICATION_OCTET_STREAM); assertHasMessageWriter(writers, forClass(ByteBuffer.class), APPLICATION_OCTET_STREAM);