diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java index a9e9088862..8fd1ac0ffe 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java @@ -144,10 +144,8 @@ public abstract class DataBufferUtils { channel -> Flux.create(sink -> { ReadCompletionHandler handler = new ReadCompletionHandler(channel, sink, position, bufferFactory, bufferSize); - sink.onDispose(handler::dispose); - DataBuffer dataBuffer = bufferFactory.allocateBuffer(bufferSize); - ByteBuffer byteBuffer = dataBuffer.asByteBuffer(0, bufferSize); - channel.read(byteBuffer, position, dataBuffer, handler); + sink.onCancel(handler::cancel); + sink.onRequest(handler::request); }), channel -> { // Do not close channel from here, rather wait for the current read callback @@ -654,7 +652,9 @@ public abstract class DataBufferUtils { private final AtomicLong position; - private final AtomicBoolean disposed = new AtomicBoolean(); + private final AtomicBoolean reading = new AtomicBoolean(); + + private final AtomicBoolean canceled = new AtomicBoolean(); public ReadCompletionHandler(AsynchronousFileChannel channel, FluxSink sink, long position, DataBufferFactory dataBufferFactory, int bufferSize) { @@ -666,43 +666,62 @@ public abstract class DataBufferUtils { this.bufferSize = bufferSize; } + public void read() { + if (this.sink.requestedFromDownstream() > 0 && this.reading.compareAndSet(false, true)) { + DataBuffer dataBuffer = this.dataBufferFactory.allocateBuffer(this.bufferSize); + ByteBuffer byteBuffer = dataBuffer.asByteBuffer(0, this.bufferSize); + this.channel.read(byteBuffer, this.position.get(), dataBuffer, this); + } + } + @Override public void completed(Integer read, DataBuffer dataBuffer) { - if (read != -1 && !this.disposed.get()) { - long pos = this.position.addAndGet(read); - dataBuffer.writePosition(read); - this.sink.next(dataBuffer); - // onNext may have led to onCancel (e.g. downstream takeUntil) - if (this.disposed.get()) { - complete(); + this.reading.set(false); + if (!isCanceled()) { + if (read != -1) { + this.position.addAndGet(read); + dataBuffer.writePosition(read); + this.sink.next(dataBuffer); + read(); } else { - DataBuffer newDataBuffer = this.dataBufferFactory.allocateBuffer(this.bufferSize); - ByteBuffer newByteBuffer = newDataBuffer.asByteBuffer(0, this.bufferSize); - this.channel.read(newByteBuffer, pos, newDataBuffer, this); + release(dataBuffer); + closeChannel(this.channel); + this.sink.complete(); } } else { release(dataBuffer); - complete(); + closeChannel(this.channel); } } - private void complete() { - this.sink.complete(); - closeChannel(this.channel); - } - @Override public void failed(Throwable exc, DataBuffer dataBuffer) { + this.reading.set(false); release(dataBuffer); - this.sink.error(exc); closeChannel(this.channel); + if (!isCanceled()) { + this.sink.error(exc); + } + } + + public void request(long n) { + read(); } - public void dispose() { - this.disposed.set(true); + public void cancel() { + if (this.canceled.compareAndSet(false, true)) { + if (!this.reading.get()) { + closeChannel(this.channel); + } + } } + + private boolean isCanceled() { + return this.canceled.get(); + } + } diff --git a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java index 5730132cf1..f5245b61fa 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java +++ b/spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java @@ -27,7 +27,6 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; -import java.util.stream.Collectors; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; @@ -249,8 +248,8 @@ public class MultipartHttpMessageWriter extends LoggingCodecSupport private Flux encodePartValues( byte[] boundary, String name, List values, DataBufferFactory bufferFactory) { - return Flux.concat(values.stream().map(v -> - encodePart(boundary, name, v, bufferFactory)).collect(Collectors.toList())); + return Flux.fromIterable(values) + .concatMap(value -> encodePart(boundary, name, value, bufferFactory)); } @SuppressWarnings("unchecked") diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java index f64d597781..f0f7db6b9c 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/WebClientIntegrationTests.java @@ -568,12 +568,13 @@ class WebClientIntegrationTests { byte[] expected = Files.readAllBytes(resource.getFile().toPath()); Flux body = DataBufferUtils.read(resource, new DefaultDataBufferFactory(), 4096); - this.webClient.post() + Mono result = this.webClient.post() .uri("/") .body(body, DataBuffer.class) .retrieve() - .bodyToMono(Void.class) - .block(Duration.ofSeconds(5)); + .bodyToMono(Void.class); + + StepVerifier.create(result).verifyComplete(); expectRequest(request -> { ByteArrayOutputStream actual = new ByteArrayOutputStream();