diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java index 93510a67bc..0cce689ade 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java @@ -31,6 +31,7 @@ import java.nio.file.StandardOpenOption; import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.IntPredicate; @@ -336,6 +337,7 @@ public abstract class DataBufferUtils { sink.next(dataBuffer); } catch (IOException ex) { + sink.next(dataBuffer); sink.error(ex); } @@ -355,6 +357,26 @@ public abstract class DataBufferUtils { * @param channel the channel to write to * @return a flux containing the same buffers as in {@code source}, that starts the writing * process when subscribed to, and that publishes any writing errors and the completion signal + * @since 5.0.10 + */ + public static Flux write( + Publisher source, AsynchronousFileChannel channel) { + return write(source, channel, 0); + } + + + /** + * Write the given stream of {@link DataBuffer DataBuffers} to the given {@code AsynchronousFileChannel}. + * Does not close the channel when the flux is terminated, and does + * not {@linkplain #release(DataBuffer) release} the data buffers in the + * source. If releasing is required, then subscribe to the returned {@code Flux} with a + * {@link #releaseConsumer()}. + *

Note that the writing process does not start until the returned {@code Flux} is subscribed to. + * @param source the stream of data buffers to be written + * @param channel the channel to write to + * @param position the file position at which the write is to begin; must be non-negative + * @return a flux containing the same buffers as in {@code source}, that starts the writing + * process when subscribed to, and that publishes any writing errors and the completion signal */ public static Flux write( Publisher source, AsynchronousFileChannel channel, long position) { @@ -610,10 +632,11 @@ public abstract class DataBufferUtils { private final AtomicBoolean completed = new AtomicBoolean(); + private final AtomicReference error = new AtomicReference<>(); + private final AtomicLong position; - @Nullable - private DataBuffer dataBuffer; + private final AtomicReference dataBuffer = new AtomicReference<>(); public AsynchronousFileChannelWriteCompletionHandler( FluxSink sink, AsynchronousFileChannel channel, long position) { @@ -630,21 +653,27 @@ public abstract class DataBufferUtils { @Override protected void hookOnNext(DataBuffer value) { - this.dataBuffer = value; + if (!this.dataBuffer.compareAndSet(null, value)) { + throw new IllegalStateException(); + } ByteBuffer byteBuffer = value.asByteBuffer(); this.channel.write(byteBuffer, this.position.get(), byteBuffer, this); } @Override protected void hookOnError(Throwable throwable) { - this.sink.error(throwable); + this.error.set(throwable); + + if (this.dataBuffer.get() == null) { + this.sink.error(throwable); + } } @Override protected void hookOnComplete() { this.completed.set(true); - if (this.dataBuffer == null) { + if (this.dataBuffer.get() == null) { this.sink.complete(); } } @@ -656,11 +685,13 @@ public abstract class DataBufferUtils { this.channel.write(byteBuffer, pos, byteBuffer, this); return; } - if (this.dataBuffer != null) { - this.sink.next(this.dataBuffer); - this.dataBuffer = null; + sinkDataBuffer(); + + Throwable throwable = this.error.get(); + if (throwable != null) { + this.sink.error(throwable); } - if (this.completed.get()) { + else if (this.completed.get()) { this.sink.complete(); } else { @@ -670,8 +701,16 @@ public abstract class DataBufferUtils { @Override public void failed(Throwable exc, ByteBuffer byteBuffer) { + sinkDataBuffer(); this.sink.error(exc); } + + private void sinkDataBuffer() { + DataBuffer dataBuffer = this.dataBuffer.get(); + Assert.state(dataBuffer != null, "DataBuffer should not be null"); + this.sink.next(dataBuffer); + this.dataBuffer.set(null); + } } diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java index 046693267b..1ade1ef9ba 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java @@ -16,10 +16,12 @@ package org.springframework.core.io.buffer; +import java.io.IOException; import java.io.OutputStream; import java.net.URI; import java.nio.ByteBuffer; import java.nio.channels.AsynchronousFileChannel; +import java.nio.channels.CompletionHandler; import java.nio.channels.FileChannel; import java.nio.channels.ReadableByteChannel; import java.nio.channels.WritableByteChannel; @@ -29,7 +31,7 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.time.Duration; -import java.util.stream.Collectors; +import java.util.concurrent.CountDownLatch; import io.netty.buffer.ByteBuf; import org.junit.Test; @@ -160,9 +162,7 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { .expectComplete() .verify(Duration.ofSeconds(5)); - String result = Files.readAllLines(tempFile) - .stream() - .collect(Collectors.joining()); + String result = String.join("", Files.readAllLines(tempFile)); assertEquals("foobarbazqux", result); os.close(); @@ -188,14 +188,60 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { .expectComplete() .verify(Duration.ofSeconds(5)); - String result = Files.readAllLines(tempFile) - .stream() - .collect(Collectors.joining()); + String result = String.join("", Files.readAllLines(tempFile)); assertEquals("foobarbazqux", result); channel.close(); } + @Test + public void writeWritableByteChannelErrorInFlux() throws Exception { + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + Flux flux = Flux.just(foo, bar).concatWith(Flux.error(new RuntimeException())); + + Path tempFile = Files.createTempFile("DataBufferUtilsTests", null); + WritableByteChannel channel = Files.newByteChannel(tempFile, StandardOpenOption.WRITE); + + Flux writeResult = DataBufferUtils.write(flux, channel); + StepVerifier.create(writeResult) + .consumeNextWith(stringConsumer("foo")) + .consumeNextWith(stringConsumer("bar")) + .expectError() + .verify(Duration.ofSeconds(5)); + + String result = String.join("", Files.readAllLines(tempFile)); + + assertEquals("foobar", result); + channel.close(); + } + + @Test + public void writeWritableByteChannelErrorInWrite() throws Exception { + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + Flux flux = Flux.just(foo, bar); + + WritableByteChannel channel = mock(WritableByteChannel.class); + when(channel.write(any())) + .thenAnswer(invocation -> { + ByteBuffer buffer = invocation.getArgument(0); + int written = buffer.remaining(); + buffer.position(buffer.limit()); + return written; + }) + .thenThrow(new IOException()); + + Flux writeResult = DataBufferUtils.write(flux, channel); + StepVerifier.create(writeResult) + .consumeNextWith(stringConsumer("foo")) + .consumeNextWith(stringConsumer("bar")) + .expectError(IOException.class) + .verify(); + + channel.close(); + } + @Test public void writeAsynchronousFileChannel() throws Exception { DataBuffer foo = stringBuffer("foo"); @@ -208,7 +254,7 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { AsynchronousFileChannel channel = AsynchronousFileChannel.open(tempFile, StandardOpenOption.WRITE); - Flux writeResult = DataBufferUtils.write(flux, channel, 0); + Flux writeResult = DataBufferUtils.write(flux, channel); StepVerifier.create(writeResult) .consumeNextWith(stringConsumer("foo")) .consumeNextWith(stringConsumer("bar")) @@ -217,14 +263,142 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { .expectComplete() .verify(Duration.ofSeconds(5)); - String result = Files.readAllLines(tempFile) - .stream() - .collect(Collectors.joining()); + String result = String.join("", Files.readAllLines(tempFile)); assertEquals("foobarbazqux", result); channel.close(); } + @Test + public void writeAsynchronousFileChannelErrorInFlux() throws Exception { + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + Flux flux = + Flux.just(foo, bar).concatWith(Mono.error(new RuntimeException())); + + Path tempFile = Files.createTempFile("DataBufferUtilsTests", null); + AsynchronousFileChannel channel = + AsynchronousFileChannel.open(tempFile, StandardOpenOption.WRITE); + + Flux writeResult = DataBufferUtils.write(flux, channel); + StepVerifier.create(writeResult) + .consumeNextWith(stringConsumer("foo")) + .consumeNextWith(stringConsumer("bar")) + .expectError(RuntimeException.class) + .verify(); + + String result = String.join("", Files.readAllLines(tempFile)); + + assertEquals("foobar", result); + channel.close(); + } + + + @Test + public void writeAsynchronousFileChannelErrorInWrite() throws Exception { + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + Flux flux = Flux.just(foo, bar); + + AsynchronousFileChannel channel = mock(AsynchronousFileChannel.class); + doAnswer(invocation -> { + ByteBuffer buffer = invocation.getArgument(0); + long pos = invocation.getArgument(1); + CompletionHandler completionHandler = invocation.getArgument(3); + + assertEquals(0, pos); + + int written = buffer.remaining(); + buffer.position(buffer.limit()); + completionHandler.completed(written, buffer); + + return null; + }) + .doAnswer(invocation -> { + ByteBuffer buffer = invocation.getArgument(0); + CompletionHandler completionHandler = + invocation.getArgument(3); + completionHandler.failed(new IOException(), buffer); + return null; + }) + .when(channel).write(isA(ByteBuffer.class), anyLong(), isA(ByteBuffer.class), + isA(CompletionHandler.class)); + + Flux writeResult = DataBufferUtils.write(flux, channel); + StepVerifier.create(writeResult) + .consumeNextWith(stringConsumer("foo")) + .consumeNextWith(stringConsumer("bar")) + .expectError(IOException.class) + .verify(); + + channel.close(); + } + + @Test + public void readAndWriteByteChannel() throws Exception { + Path source = Paths.get( + DataBufferUtilsTests.class.getResource("DataBufferUtilsTests.txt").toURI()); + Flux sourceFlux = + DataBufferUtils + .readByteChannel(() -> FileChannel.open(source, StandardOpenOption.READ), + this.bufferFactory, 3); + + Path destination = Files.createTempFile("DataBufferUtilsTests", null); + WritableByteChannel channel = Files.newByteChannel(destination, StandardOpenOption.WRITE); + + DataBufferUtils.write(sourceFlux, channel) + .subscribe(DataBufferUtils.releaseConsumer(), + throwable -> fail(throwable.getMessage()), + () -> { + try { + String expected = String.join("", Files.readAllLines(source)); + String result = String.join("", Files.readAllLines(destination)); + + assertEquals(expected, result); + channel.close(); + + } + catch (IOException e) { + fail(e.getMessage()); + } + }); + } + + @Test + public void readAndWriteAsynchronousFileChannel() throws Exception { + Path source = Paths.get( + DataBufferUtilsTests.class.getResource("DataBufferUtilsTests.txt").toURI()); + Flux sourceFlux = DataBufferUtils.readAsynchronousFileChannel( + () -> AsynchronousFileChannel.open(source, StandardOpenOption.READ), + this.bufferFactory, 3); + + Path destination = Files.createTempFile("DataBufferUtilsTests", null); + AsynchronousFileChannel channel = + AsynchronousFileChannel.open(destination, StandardOpenOption.WRITE); + + CountDownLatch latch = new CountDownLatch(1); + + DataBufferUtils.write(sourceFlux, channel) + .subscribe(DataBufferUtils::release, + throwable -> fail(throwable.getMessage()), + () -> { + try { + String expected = String.join("", Files.readAllLines(source)); + String result = String.join("", Files.readAllLines(destination)); + + assertEquals(expected, result); + channel.close(); + latch.countDown(); + + } + catch (IOException e) { + fail(e.getMessage()); + } + }); + + latch.await(); + } + @Test public void takeUntilByteCount() { @@ -314,7 +488,8 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { .thenAnswer(putByte('c')) .thenReturn(-1); - Flux read = DataBufferUtils.readByteChannel(() -> channel, this.bufferFactory, 1); + Flux read = + DataBufferUtils.readByteChannel(() -> channel, this.bufferFactory, 1); StepVerifier.create(read) .consumeNextWith(stringConsumer("a")) @@ -343,9 +518,10 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { StepVerifier.create(result) .consumeNextWith(dataBuffer -> { - assertEquals("foobarbaz", DataBufferTestUtils.dumpString(dataBuffer, StandardCharsets.UTF_8)); - release(dataBuffer); - }) + assertEquals("foobarbaz", + DataBufferTestUtils.dumpString(dataBuffer, StandardCharsets.UTF_8)); + release(dataBuffer); + }) .verifyComplete(); }