From 30af01fd4e9e35c21528e23a64889a1e151cac1e Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Fri, 14 Jun 2019 10:31:17 +0200 Subject: [PATCH] Use DataBufferUtils.write in DefaultFilePart.transferTo This commit makes sure that in DefaultMultipartMessageReader's DefaultFilePart, the file is not closed before all bytes are written, by using DataBufferUtils.write (see c1b6885191d6a50347aeaa14da994f0db88f26fe). The commit also improves on the logging of the DefaultMultipartMessageReader. Closes gh-23130 --- .../DefaultMultipartMessageReader.java | 54 ++++--------------- .../annotation/MultipartIntegrationTests.java | 48 +++++++++++++++++ 2 files changed, 59 insertions(+), 43 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/DefaultMultipartMessageReader.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/DefaultMultipartMessageReader.java index ca1c3e2566..95aa0fa30a 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/DefaultMultipartMessageReader.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/DefaultMultipartMessageReader.java @@ -16,14 +16,9 @@ package org.springframework.http.codec.multipart; -import java.io.IOException; -import java.nio.channels.AsynchronousFileChannel; -import java.nio.channels.Channel; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import java.nio.file.OpenOption; import java.nio.file.Path; -import java.nio.file.StandardOpenOption; import java.util.Collections; import java.util.List; import java.util.Map; @@ -100,10 +95,6 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement return Flux.error(new CodecException("No multipart boundary found in Content-Type: \"" + message.getHeaders().getContentType() + "\"")); } - if (logger.isTraceEnabled()) { - logger.trace("Boundary: " + toString(boundary)); - } - byte[] boundaryNeedle = concat(BOUNDARY_PREFIX, boundary); Flux body = skipUntilFirstBoundary(message.getBody(), boundary); @@ -148,8 +139,10 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement DataBuffer slice = dataBuffer.retainedSlice(endIdx + 1, length); DataBufferUtils.release(dataBuffer); if (logger.isTraceEnabled()) { - logger.trace("Found first boundary at " + endIdx + " in " + toString(dataBuffer)); - } + logger.trace( + "Found last byte of first boundary (" + toString(boundary) + + ") at " + endIdx); + } return Mono.just(slice); } else { @@ -188,14 +181,14 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement } } - if (logger.isTraceEnabled()) { - logger.trace("Part data: " + toString(dataBuffer)); - } int endIdx = HEADER_MATCHER.match(dataBuffer); HttpHeaders headers; DataBuffer body; if (endIdx > 0) { + if (logger.isTraceEnabled()) { + logger.trace("Found last byte of part header at " + endIdx ); + } readPosition = dataBuffer.readPosition(); int headersLength = endIdx + 1 - (readPosition + HEADER_BODY_SEPARATOR.length); DataBuffer headersBuffer = dataBuffer.retainedSlice(readPosition, headersLength); @@ -204,6 +197,9 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement headers = toHeaders(headersBuffer); } else { + if (logger.isTraceEnabled()) { + logger.trace("No header found"); + } headers = new HttpHeaders(); body = DataBufferUtils.retain(dataBuffer); } @@ -252,16 +248,6 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement return result; } - - private static String toString(DataBuffer dataBuffer) { - byte[] bytes = new byte[dataBuffer.readableByteCount()]; - int j = 0; - for (int i = dataBuffer.readPosition(); i < dataBuffer.writePosition(); i++) { - bytes[j++] = dataBuffer.getByte(i); - } - return toString(bytes); - } - private static String toString(byte[] bytes) { StringBuilder builder = new StringBuilder(); for (byte b : bytes) { @@ -368,10 +354,6 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement private static class DefaultFilePart extends DefaultPart implements FilePart { - private static final OpenOption[] FILE_CHANNEL_OPTIONS = - {StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE}; - - public DefaultFilePart(HttpHeaders headers, DataBuffer body) { super(headers, body); } @@ -385,23 +367,9 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement @Override public Mono transferTo(Path dest) { - return Mono.using(() -> AsynchronousFileChannel.open(dest, FILE_CHANNEL_OPTIONS), - this::writeBody, this::close); - } - - private Mono writeBody(AsynchronousFileChannel channel) { - return DataBufferUtils.write(content(), channel) - .map(DataBufferUtils::release) - .then(); + return DataBufferUtils.write(content(), dest); } - private void close(Channel channel) { - try { - channel.close(); - } - catch (IOException ignore) { - } - } } } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java index 7a6b102b3e..be7dc30f42 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java @@ -16,6 +16,10 @@ package org.springframework.web.reactive.result.method.annotation; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.List; import java.util.stream.Collectors; @@ -31,6 +35,7 @@ import org.springframework.context.annotation.AnnotationConfigApplicationContext import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; import org.springframework.http.HttpEntity; import org.springframework.http.HttpStatus; import org.springframework.http.client.MultipartBodyBuilder; @@ -145,6 +150,34 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes .verifyComplete(); } + @Test + public void transferTo() { + Flux result = webClient + .post() + .uri("/transferTo") + .syncBody(generateBody()) + .retrieve() + .bodyToFlux(String.class); + + StepVerifier.create(result) + .consumeNextWith(filename -> verifyContents(Paths.get(filename), new ClassPathResource("foo.txt", MultipartHttpMessageReader.class))) + .consumeNextWith(filename -> verifyContents(Paths.get(filename), new ClassPathResource("logo.png", getClass()))) + .verifyComplete(); + + } + + private static void verifyContents(Path tempFile, Resource resource) { + try { + byte[] tempBytes = Files.readAllBytes(tempFile); + byte[] resourceBytes = Files.readAllBytes(resource.getFile().toPath()); + assertThat(tempBytes).isEqualTo(resourceBytes); + } + catch (IOException ex) { + throw new AssertionError(ex); + } + } + + @Test public void modelAttribute() { Mono result = webClient @@ -217,6 +250,21 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes return partFluxDescription(Flux.from(parts)); } + @PostMapping("/transferTo") + Flux transferTo(@RequestPart("fileParts") Flux parts) { + return parts.flatMap(filePart -> { + try { + Path tempFile = Files.createTempFile("MultipartIntegrationTests", filePart.filename()); + return filePart.transferTo(tempFile) + .then(Mono.just(tempFile.toString() + "\n")); + + } + catch (IOException e) { + return Mono.error(e); + } + }); + } + @PostMapping("/modelAttribute") String modelAttribute(@ModelAttribute FormBean formBean) { return formBean.toString();