From 952315c333abf3ab587731b567d809f0a20da5f7 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Tue, 11 Sep 2018 13:24:03 +0200 Subject: [PATCH] DataBufferUtils does not release DataBuffer on error cases This commit makes sure that in DataBufferUtils.write, any received data buffers are returned as part of the returned flux, even when an error occurs or is received. Issue: SPR-16782 (cherry picked from commit 1a0522b8057dead47c90404d629c7597d0787dce) --- .../core/io/buffer/DataBufferUtils.java | 57 ++++- .../core/io/buffer/DataBufferUtilsTests.java | 206 ++++++++++++++++-- 2 files changed, 239 insertions(+), 24 deletions(-) diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java index 93510a67bc..0cce689ade 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java @@ -31,6 +31,7 @@ import java.nio.file.StandardOpenOption; import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.IntPredicate; @@ -336,6 +337,7 @@ public abstract class DataBufferUtils { sink.next(dataBuffer); } catch (IOException ex) { + sink.next(dataBuffer); sink.error(ex); } @@ -355,6 +357,26 @@ public abstract class DataBufferUtils { * @param channel the channel to write to * @return a flux containing the same buffers as in {@code source}, that starts the writing * process when subscribed to, and that publishes any writing errors and the completion signal + * @since 5.0.10 + */ + public static Flux write( + Publisher source, AsynchronousFileChannel channel) { + return write(source, channel, 0); + } + + + /** + * Write the given stream of {@link DataBuffer DataBuffers} to the given {@code AsynchronousFileChannel}. + * Does not close the channel when the flux is terminated, and does + * not {@linkplain #release(DataBuffer) release} the data buffers in the + * source. If releasing is required, then subscribe to the returned {@code Flux} with a + * {@link #releaseConsumer()}. + *

Note that the writing process does not start until the returned {@code Flux} is subscribed to. + * @param source the stream of data buffers to be written + * @param channel the channel to write to + * @param position the file position at which the write is to begin; must be non-negative + * @return a flux containing the same buffers as in {@code source}, that starts the writing + * process when subscribed to, and that publishes any writing errors and the completion signal */ public static Flux write( Publisher source, AsynchronousFileChannel channel, long position) { @@ -610,10 +632,11 @@ public abstract class DataBufferUtils { private final AtomicBoolean completed = new AtomicBoolean(); + private final AtomicReference error = new AtomicReference<>(); + private final AtomicLong position; - @Nullable - private DataBuffer dataBuffer; + private final AtomicReference dataBuffer = new AtomicReference<>(); public AsynchronousFileChannelWriteCompletionHandler( FluxSink sink, AsynchronousFileChannel channel, long position) { @@ -630,21 +653,27 @@ public abstract class DataBufferUtils { @Override protected void hookOnNext(DataBuffer value) { - this.dataBuffer = value; + if (!this.dataBuffer.compareAndSet(null, value)) { + throw new IllegalStateException(); + } ByteBuffer byteBuffer = value.asByteBuffer(); this.channel.write(byteBuffer, this.position.get(), byteBuffer, this); } @Override protected void hookOnError(Throwable throwable) { - this.sink.error(throwable); + this.error.set(throwable); + + if (this.dataBuffer.get() == null) { + this.sink.error(throwable); + } } @Override protected void hookOnComplete() { this.completed.set(true); - if (this.dataBuffer == null) { + if (this.dataBuffer.get() == null) { this.sink.complete(); } } @@ -656,11 +685,13 @@ public abstract class DataBufferUtils { this.channel.write(byteBuffer, pos, byteBuffer, this); return; } - if (this.dataBuffer != null) { - this.sink.next(this.dataBuffer); - this.dataBuffer = null; + sinkDataBuffer(); + + Throwable throwable = this.error.get(); + if (throwable != null) { + this.sink.error(throwable); } - if (this.completed.get()) { + else if (this.completed.get()) { this.sink.complete(); } else { @@ -670,8 +701,16 @@ public abstract class DataBufferUtils { @Override public void failed(Throwable exc, ByteBuffer byteBuffer) { + sinkDataBuffer(); this.sink.error(exc); } + + private void sinkDataBuffer() { + DataBuffer dataBuffer = this.dataBuffer.get(); + Assert.state(dataBuffer != null, "DataBuffer should not be null"); + this.sink.next(dataBuffer); + this.dataBuffer.set(null); + } } diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java index 046693267b..1ade1ef9ba 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java @@ -16,10 +16,12 @@ package org.springframework.core.io.buffer; +import java.io.IOException; import java.io.OutputStream; import java.net.URI; import java.nio.ByteBuffer; import java.nio.channels.AsynchronousFileChannel; +import java.nio.channels.CompletionHandler; import java.nio.channels.FileChannel; import java.nio.channels.ReadableByteChannel; import java.nio.channels.WritableByteChannel; @@ -29,7 +31,7 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.time.Duration; -import java.util.stream.Collectors; +import java.util.concurrent.CountDownLatch; import io.netty.buffer.ByteBuf; import org.junit.Test; @@ -160,9 +162,7 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { .expectComplete() .verify(Duration.ofSeconds(5)); - String result = Files.readAllLines(tempFile) - .stream() - .collect(Collectors.joining()); + String result = String.join("", Files.readAllLines(tempFile)); assertEquals("foobarbazqux", result); os.close(); @@ -188,14 +188,60 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { .expectComplete() .verify(Duration.ofSeconds(5)); - String result = Files.readAllLines(tempFile) - .stream() - .collect(Collectors.joining()); + String result = String.join("", Files.readAllLines(tempFile)); assertEquals("foobarbazqux", result); channel.close(); } + @Test + public void writeWritableByteChannelErrorInFlux() throws Exception { + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + Flux flux = Flux.just(foo, bar).concatWith(Flux.error(new RuntimeException())); + + Path tempFile = Files.createTempFile("DataBufferUtilsTests", null); + WritableByteChannel channel = Files.newByteChannel(tempFile, StandardOpenOption.WRITE); + + Flux writeResult = DataBufferUtils.write(flux, channel); + StepVerifier.create(writeResult) + .consumeNextWith(stringConsumer("foo")) + .consumeNextWith(stringConsumer("bar")) + .expectError() + .verify(Duration.ofSeconds(5)); + + String result = String.join("", Files.readAllLines(tempFile)); + + assertEquals("foobar", result); + channel.close(); + } + + @Test + public void writeWritableByteChannelErrorInWrite() throws Exception { + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + Flux flux = Flux.just(foo, bar); + + WritableByteChannel channel = mock(WritableByteChannel.class); + when(channel.write(any())) + .thenAnswer(invocation -> { + ByteBuffer buffer = invocation.getArgument(0); + int written = buffer.remaining(); + buffer.position(buffer.limit()); + return written; + }) + .thenThrow(new IOException()); + + Flux writeResult = DataBufferUtils.write(flux, channel); + StepVerifier.create(writeResult) + .consumeNextWith(stringConsumer("foo")) + .consumeNextWith(stringConsumer("bar")) + .expectError(IOException.class) + .verify(); + + channel.close(); + } + @Test public void writeAsynchronousFileChannel() throws Exception { DataBuffer foo = stringBuffer("foo"); @@ -208,7 +254,7 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { AsynchronousFileChannel channel = AsynchronousFileChannel.open(tempFile, StandardOpenOption.WRITE); - Flux writeResult = DataBufferUtils.write(flux, channel, 0); + Flux writeResult = DataBufferUtils.write(flux, channel); StepVerifier.create(writeResult) .consumeNextWith(stringConsumer("foo")) .consumeNextWith(stringConsumer("bar")) @@ -217,14 +263,142 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { .expectComplete() .verify(Duration.ofSeconds(5)); - String result = Files.readAllLines(tempFile) - .stream() - .collect(Collectors.joining()); + String result = String.join("", Files.readAllLines(tempFile)); assertEquals("foobarbazqux", result); channel.close(); } + @Test + public void writeAsynchronousFileChannelErrorInFlux() throws Exception { + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + Flux flux = + Flux.just(foo, bar).concatWith(Mono.error(new RuntimeException())); + + Path tempFile = Files.createTempFile("DataBufferUtilsTests", null); + AsynchronousFileChannel channel = + AsynchronousFileChannel.open(tempFile, StandardOpenOption.WRITE); + + Flux writeResult = DataBufferUtils.write(flux, channel); + StepVerifier.create(writeResult) + .consumeNextWith(stringConsumer("foo")) + .consumeNextWith(stringConsumer("bar")) + .expectError(RuntimeException.class) + .verify(); + + String result = String.join("", Files.readAllLines(tempFile)); + + assertEquals("foobar", result); + channel.close(); + } + + + @Test + public void writeAsynchronousFileChannelErrorInWrite() throws Exception { + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + Flux flux = Flux.just(foo, bar); + + AsynchronousFileChannel channel = mock(AsynchronousFileChannel.class); + doAnswer(invocation -> { + ByteBuffer buffer = invocation.getArgument(0); + long pos = invocation.getArgument(1); + CompletionHandler completionHandler = invocation.getArgument(3); + + assertEquals(0, pos); + + int written = buffer.remaining(); + buffer.position(buffer.limit()); + completionHandler.completed(written, buffer); + + return null; + }) + .doAnswer(invocation -> { + ByteBuffer buffer = invocation.getArgument(0); + CompletionHandler completionHandler = + invocation.getArgument(3); + completionHandler.failed(new IOException(), buffer); + return null; + }) + .when(channel).write(isA(ByteBuffer.class), anyLong(), isA(ByteBuffer.class), + isA(CompletionHandler.class)); + + Flux writeResult = DataBufferUtils.write(flux, channel); + StepVerifier.create(writeResult) + .consumeNextWith(stringConsumer("foo")) + .consumeNextWith(stringConsumer("bar")) + .expectError(IOException.class) + .verify(); + + channel.close(); + } + + @Test + public void readAndWriteByteChannel() throws Exception { + Path source = Paths.get( + DataBufferUtilsTests.class.getResource("DataBufferUtilsTests.txt").toURI()); + Flux sourceFlux = + DataBufferUtils + .readByteChannel(() -> FileChannel.open(source, StandardOpenOption.READ), + this.bufferFactory, 3); + + Path destination = Files.createTempFile("DataBufferUtilsTests", null); + WritableByteChannel channel = Files.newByteChannel(destination, StandardOpenOption.WRITE); + + DataBufferUtils.write(sourceFlux, channel) + .subscribe(DataBufferUtils.releaseConsumer(), + throwable -> fail(throwable.getMessage()), + () -> { + try { + String expected = String.join("", Files.readAllLines(source)); + String result = String.join("", Files.readAllLines(destination)); + + assertEquals(expected, result); + channel.close(); + + } + catch (IOException e) { + fail(e.getMessage()); + } + }); + } + + @Test + public void readAndWriteAsynchronousFileChannel() throws Exception { + Path source = Paths.get( + DataBufferUtilsTests.class.getResource("DataBufferUtilsTests.txt").toURI()); + Flux sourceFlux = DataBufferUtils.readAsynchronousFileChannel( + () -> AsynchronousFileChannel.open(source, StandardOpenOption.READ), + this.bufferFactory, 3); + + Path destination = Files.createTempFile("DataBufferUtilsTests", null); + AsynchronousFileChannel channel = + AsynchronousFileChannel.open(destination, StandardOpenOption.WRITE); + + CountDownLatch latch = new CountDownLatch(1); + + DataBufferUtils.write(sourceFlux, channel) + .subscribe(DataBufferUtils::release, + throwable -> fail(throwable.getMessage()), + () -> { + try { + String expected = String.join("", Files.readAllLines(source)); + String result = String.join("", Files.readAllLines(destination)); + + assertEquals(expected, result); + channel.close(); + latch.countDown(); + + } + catch (IOException e) { + fail(e.getMessage()); + } + }); + + latch.await(); + } + @Test public void takeUntilByteCount() { @@ -314,7 +488,8 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { .thenAnswer(putByte('c')) .thenReturn(-1); - Flux read = DataBufferUtils.readByteChannel(() -> channel, this.bufferFactory, 1); + Flux read = + DataBufferUtils.readByteChannel(() -> channel, this.bufferFactory, 1); StepVerifier.create(read) .consumeNextWith(stringConsumer("a")) @@ -343,9 +518,10 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { StepVerifier.create(result) .consumeNextWith(dataBuffer -> { - assertEquals("foobarbaz", DataBufferTestUtils.dumpString(dataBuffer, StandardCharsets.UTF_8)); - release(dataBuffer); - }) + assertEquals("foobarbaz", + DataBufferTestUtils.dumpString(dataBuffer, StandardCharsets.UTF_8)); + release(dataBuffer); + }) .verifyComplete(); }