Browse Source

Use DataBufferUtils.write in DefaultFilePart.transferTo

This commit makes sure that in DefaultMultipartMessageReader's
DefaultFilePart, the file is not closed before all bytes are written,
by using DataBufferUtils.write (see c1b6885191d6a50347aeaa14da994f0db88f26fe).

The commit also improves on the logging of the
DefaultMultipartMessageReader.

Closes gh-23130
pull/23141/head
Arjen Poutsma 6 years ago
parent
commit
30af01fd4e
  1. 54
      spring-web/src/main/java/org/springframework/http/codec/multipart/DefaultMultipartMessageReader.java
  2. 48
      spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java

54
spring-web/src/main/java/org/springframework/http/codec/multipart/DefaultMultipartMessageReader.java

@ -16,14 +16,9 @@
package org.springframework.http.codec.multipart; package org.springframework.http.codec.multipart;
import java.io.IOException;
import java.nio.channels.AsynchronousFileChannel;
import java.nio.channels.Channel;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.OpenOption;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -100,10 +95,6 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement
return Flux.error(new CodecException("No multipart boundary found in Content-Type: \"" + return Flux.error(new CodecException("No multipart boundary found in Content-Type: \"" +
message.getHeaders().getContentType() + "\"")); message.getHeaders().getContentType() + "\""));
} }
if (logger.isTraceEnabled()) {
logger.trace("Boundary: " + toString(boundary));
}
byte[] boundaryNeedle = concat(BOUNDARY_PREFIX, boundary); byte[] boundaryNeedle = concat(BOUNDARY_PREFIX, boundary);
Flux<DataBuffer> body = skipUntilFirstBoundary(message.getBody(), boundary); Flux<DataBuffer> body = skipUntilFirstBoundary(message.getBody(), boundary);
@ -148,8 +139,10 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement
DataBuffer slice = dataBuffer.retainedSlice(endIdx + 1, length); DataBuffer slice = dataBuffer.retainedSlice(endIdx + 1, length);
DataBufferUtils.release(dataBuffer); DataBufferUtils.release(dataBuffer);
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Found first boundary at " + endIdx + " in " + toString(dataBuffer)); logger.trace(
} "Found last byte of first boundary (" + toString(boundary)
+ ") at " + endIdx);
}
return Mono.just(slice); return Mono.just(slice);
} }
else { else {
@ -188,14 +181,14 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement
} }
} }
if (logger.isTraceEnabled()) {
logger.trace("Part data: " + toString(dataBuffer));
}
int endIdx = HEADER_MATCHER.match(dataBuffer); int endIdx = HEADER_MATCHER.match(dataBuffer);
HttpHeaders headers; HttpHeaders headers;
DataBuffer body; DataBuffer body;
if (endIdx > 0) { if (endIdx > 0) {
if (logger.isTraceEnabled()) {
logger.trace("Found last byte of part header at " + endIdx );
}
readPosition = dataBuffer.readPosition(); readPosition = dataBuffer.readPosition();
int headersLength = endIdx + 1 - (readPosition + HEADER_BODY_SEPARATOR.length); int headersLength = endIdx + 1 - (readPosition + HEADER_BODY_SEPARATOR.length);
DataBuffer headersBuffer = dataBuffer.retainedSlice(readPosition, headersLength); DataBuffer headersBuffer = dataBuffer.retainedSlice(readPosition, headersLength);
@ -204,6 +197,9 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement
headers = toHeaders(headersBuffer); headers = toHeaders(headersBuffer);
} }
else { else {
if (logger.isTraceEnabled()) {
logger.trace("No header found");
}
headers = new HttpHeaders(); headers = new HttpHeaders();
body = DataBufferUtils.retain(dataBuffer); body = DataBufferUtils.retain(dataBuffer);
} }
@ -252,16 +248,6 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement
return result; return result;
} }
private static String toString(DataBuffer dataBuffer) {
byte[] bytes = new byte[dataBuffer.readableByteCount()];
int j = 0;
for (int i = dataBuffer.readPosition(); i < dataBuffer.writePosition(); i++) {
bytes[j++] = dataBuffer.getByte(i);
}
return toString(bytes);
}
private static String toString(byte[] bytes) { private static String toString(byte[] bytes) {
StringBuilder builder = new StringBuilder(); StringBuilder builder = new StringBuilder();
for (byte b : bytes) { for (byte b : bytes) {
@ -368,10 +354,6 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement
private static class DefaultFilePart extends DefaultPart implements FilePart { private static class DefaultFilePart extends DefaultPart implements FilePart {
private static final OpenOption[] FILE_CHANNEL_OPTIONS =
{StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE};
public DefaultFilePart(HttpHeaders headers, DataBuffer body) { public DefaultFilePart(HttpHeaders headers, DataBuffer body) {
super(headers, body); super(headers, body);
} }
@ -385,23 +367,9 @@ public class DefaultMultipartMessageReader extends LoggingCodecSupport implement
@Override @Override
public Mono<Void> transferTo(Path dest) { public Mono<Void> transferTo(Path dest) {
return Mono.using(() -> AsynchronousFileChannel.open(dest, FILE_CHANNEL_OPTIONS), return DataBufferUtils.write(content(), dest);
this::writeBody, this::close);
}
private Mono<Void> writeBody(AsynchronousFileChannel channel) {
return DataBufferUtils.write(content(), channel)
.map(DataBufferUtils::release)
.then();
} }
private void close(Channel channel) {
try {
channel.close();
}
catch (IOException ignore) {
}
}
} }
} }

48
spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/MultipartIntegrationTests.java

@ -16,6 +16,10 @@
package org.springframework.web.reactive.result.method.annotation; package org.springframework.web.reactive.result.method.annotation;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -31,6 +35,7 @@ import org.springframework.context.annotation.AnnotationConfigApplicationContext
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpEntity; import org.springframework.http.HttpEntity;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.client.MultipartBodyBuilder; import org.springframework.http.client.MultipartBodyBuilder;
@ -145,6 +150,34 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes
.verifyComplete(); .verifyComplete();
} }
@Test
public void transferTo() {
Flux<String> result = webClient
.post()
.uri("/transferTo")
.syncBody(generateBody())
.retrieve()
.bodyToFlux(String.class);
StepVerifier.create(result)
.consumeNextWith(filename -> verifyContents(Paths.get(filename), new ClassPathResource("foo.txt", MultipartHttpMessageReader.class)))
.consumeNextWith(filename -> verifyContents(Paths.get(filename), new ClassPathResource("logo.png", getClass())))
.verifyComplete();
}
private static void verifyContents(Path tempFile, Resource resource) {
try {
byte[] tempBytes = Files.readAllBytes(tempFile);
byte[] resourceBytes = Files.readAllBytes(resource.getFile().toPath());
assertThat(tempBytes).isEqualTo(resourceBytes);
}
catch (IOException ex) {
throw new AssertionError(ex);
}
}
@Test @Test
public void modelAttribute() { public void modelAttribute() {
Mono<String> result = webClient Mono<String> result = webClient
@ -217,6 +250,21 @@ public class MultipartIntegrationTests extends AbstractHttpHandlerIntegrationTes
return partFluxDescription(Flux.from(parts)); return partFluxDescription(Flux.from(parts));
} }
@PostMapping("/transferTo")
Flux<String> transferTo(@RequestPart("fileParts") Flux<FilePart> parts) {
return parts.flatMap(filePart -> {
try {
Path tempFile = Files.createTempFile("MultipartIntegrationTests", filePart.filename());
return filePart.transferTo(tempFile)
.then(Mono.just(tempFile.toString() + "\n"));
}
catch (IOException e) {
return Mono.error(e);
}
});
}
@PostMapping("/modelAttribute") @PostMapping("/modelAttribute")
String modelAttribute(@ModelAttribute FormBean formBean) { String modelAttribute(@ModelAttribute FormBean formBean) {
return formBean.toString(); return formBean.toString();

Loading…
Cancel
Save