diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/DefaultMultipartMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/DefaultMultipartMessageReader.java index ca1c3e2566..95aa0fa30a 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/DefaultMultipartMessageReader.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/DefaultMultipartMessageReader.java @@ -16,14 +16,9 @@ package org.springframework.http.codec.multipart; -import java.io.IOException; -import java.nio.channels.AsynchronousFileChannel; -import java.nio.channels.Channel; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import java.nio.file.OpenOption; import java.nio.file.Path; -import java.nio.file.StandardOpenOption; import java.util.Collections; import java.util.List; import java.util.Map; @@ -100,10 +95,6 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement return Flux.error(new CodecException("No multipart boundary found in Content-Type: \"" + message.getHeaders().getContentType() + "\"")); } - if (logger.isTraceEnabled()) { - logger.trace("Boundary: " + toString(boundary)); - } - byte[] boundaryNeedle = concat(BOUNDARY_PREFIX, boundary); Flux body = skipUntilFirstBoundary(message.getBody(), boundary); @@ -148,8 +139,10 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement DataBuffer slice = dataBuffer.retainedSlice(endIdx + 1, length); DataBufferUtils.release(dataBuffer); if (logger.isTraceEnabled()) { - logger.trace("Found first boundary at " + endIdx + " in " + toString(dataBuffer)); - } + logger.trace( + "Found last byte of first boundary (" + toString(boundary) + + ") at " + endIdx); + } return Mono.just(slice); } else { @@ -188,14 +181,14 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement } } - if (logger.isTraceEnabled()) { - logger.trace("Part data: " + toString(dataBuffer)); - } int endIdx = HEADER_MATCHER.match(dataBuffer); HttpHeaders headers; DataBuffer body; if (endIdx > 0) { + if (logger.isTraceEnabled()) { + logger.trace("Found last byte of part header at " + endIdx ); + } readPosition = dataBuffer.readPosition(); int headersLength = endIdx + 1 - (readPosition + HEADER_BODY_SEPARATOR.length); DataBuffer headersBuffer = dataBuffer.retainedSlice(readPosition, headersLength); @@ -204,6 +197,9 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement headers = toHeaders(headersBuffer); } else { + if (logger.isTraceEnabled()) { + logger.trace("No header found"); + } headers = new HttpHeaders(); body = DataBufferUtils.retain(dataBuffer); } @@ -252,16 +248,6 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement return result; } - - private static String toString(DataBuffer dataBuffer) { - byte[] bytes = new byte[dataBuffer.readableByteCount()]; - int j = 0; - for (int i = dataBuffer.readPosition(); i < dataBuffer.writePosition(); i++) { - bytes[j++] = dataBuffer.getByte(i); - } - return toString(bytes); - } - private static String toString(byte[] bytes) { StringBuilder builder = new StringBuilder(); for (byte b : bytes) { @@ -368,10 +354,6 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement private static class DefaultFilePart extends DefaultPart implements FilePart { - private static final OpenOption[] FILE_CHANNEL_OPTIONS = - {StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE}; - - public DefaultFilePart(HttpHeaders headers, DataBuffer body) { super(headers, body); } @@ -385,23 +367,9 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement @Override public Mono transferTo(Path dest) { - return Mono.using(() -> AsynchronousFileChannel.open(dest, FILE_CHANNEL_OPTIONS), - this::writeBody, this::close); - } - - private Mono writeBody(AsynchronousFileChannel channel) { - return DataBufferUtils.write(content(), channel) - .map(DataBufferUtils::release) - .then(); + return DataBufferUtils.write(content(), dest); } - private void close(Channel channel) { - try { - channel.close(); - } - catch (IOException ignore) { - } - } } } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java index 7a6b102b3e..be7dc30f42 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java @@ -16,6 +16,10 @@ package org.springframework.web.reactive.result.method.annotation; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.List; import java.util.stream.Collectors; @@ -31,6 +35,7 @@ import org.springframework.context.annotation.AnnotationConfigApplicationContext import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; import org.springframework.http.HttpEntity; import org.springframework.http.HttpStatus; import org.springframework.http.client.MultipartBodyBuilder; @@ -145,6 +150,34 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes .verifyComplete(); } + @Test + public void transferTo() { + Flux result = webClient + .post() + .uri("/transferTo") + .syncBody(generateBody()) + .retrieve() + .bodyToFlux(String.class); + + StepVerifier.create(result) + .consumeNextWith(filename -> verifyContents(Paths.get(filename), new ClassPathResource("foo.txt", MultipartHttpMessageReader.class))) + .consumeNextWith(filename -> verifyContents(Paths.get(filename), new ClassPathResource("logo.png", getClass()))) + .verifyComplete(); + + } + + private static void verifyContents(Path tempFile, Resource resource) { + try { + byte[] tempBytes = Files.readAllBytes(tempFile); + byte[] resourceBytes = Files.readAllBytes(resource.getFile().toPath()); + assertThat(tempBytes).isEqualTo(resourceBytes); + } + catch (IOException ex) { + throw new AssertionError(ex); + } + } + + @Test public void modelAttribute() { Mono result = webClient @@ -217,6 +250,21 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes return partFluxDescription(Flux.from(parts)); } + @PostMapping("/transferTo") + Flux transferTo(@RequestPart("fileParts") Flux parts) { + return parts.flatMap(filePart -> { + try { + Path tempFile = Files.createTempFile("MultipartIntegrationTests", filePart.filename()); + return filePart.transferTo(tempFile) + .then(Mono.just(tempFile.toString() + "\n")); + + } + catch (IOException e) { + return Mono.error(e); + } + }); + } + @PostMapping("/modelAttribute") String modelAttribute(@ModelAttribute FormBean formBean) { return formBean.toString();