Browse Source

DataBufferUtils does not release DataBuffer on error cases

This commit makes sure that in DataBufferUtils.write, any received data
buffers are returned as part of the returned flux, even when an error
occurs or is received.

Issue: SPR-16782

(cherry picked from commit 1a0522b805)
pull/1998/head
Arjen Poutsma 6 years ago
parent
commit
952315c333
  1. 57
      spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java
  2. 206
      spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java

57
spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java

@ -31,6 +31,7 @@ import java.nio.file.StandardOpenOption; @@ -31,6 +31,7 @@ import java.nio.file.StandardOpenOption;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.IntPredicate;
@ -336,6 +337,7 @@ public abstract class DataBufferUtils { @@ -336,6 +337,7 @@ public abstract class DataBufferUtils {
sink.next(dataBuffer);
}
catch (IOException ex) {
sink.next(dataBuffer);
sink.error(ex);
}
@ -355,6 +357,26 @@ public abstract class DataBufferUtils { @@ -355,6 +357,26 @@ public abstract class DataBufferUtils {
* @param channel the channel to write to
* @return a flux containing the same buffers as in {@code source}, that starts the writing
* process when subscribed to, and that publishes any writing errors and the completion signal
* @since 5.0.10
*/
public static Flux<DataBuffer> write(
Publisher<DataBuffer> source, AsynchronousFileChannel channel) {
return write(source, channel, 0);
}
/**
* Write the given stream of {@link DataBuffer DataBuffers} to the given {@code AsynchronousFileChannel}.
* Does <strong>not</strong> close the channel when the flux is terminated, and does
* <strong>not</strong> {@linkplain #release(DataBuffer) release} the data buffers in the
* source. If releasing is required, then subscribe to the returned {@code Flux} with a
* {@link #releaseConsumer()}.
* <p>Note that the writing process does not start until the returned {@code Flux} is subscribed to.
* @param source the stream of data buffers to be written
* @param channel the channel to write to
* @param position the file position at which the write is to begin; must be non-negative
* @return a flux containing the same buffers as in {@code source}, that starts the writing
* process when subscribed to, and that publishes any writing errors and the completion signal
*/
public static Flux<DataBuffer> write(
Publisher<DataBuffer> source, AsynchronousFileChannel channel, long position) {
@ -610,10 +632,11 @@ public abstract class DataBufferUtils { @@ -610,10 +632,11 @@ public abstract class DataBufferUtils {
private final AtomicBoolean completed = new AtomicBoolean();
private final AtomicReference<Throwable> error = new AtomicReference<>();
private final AtomicLong position;
@Nullable
private DataBuffer dataBuffer;
private final AtomicReference<DataBuffer> dataBuffer = new AtomicReference<>();
public AsynchronousFileChannelWriteCompletionHandler(
FluxSink<DataBuffer> sink, AsynchronousFileChannel channel, long position) {
@ -630,21 +653,27 @@ public abstract class DataBufferUtils { @@ -630,21 +653,27 @@ public abstract class DataBufferUtils {
@Override
protected void hookOnNext(DataBuffer value) {
this.dataBuffer = value;
if (!this.dataBuffer.compareAndSet(null, value)) {
throw new IllegalStateException();
}
ByteBuffer byteBuffer = value.asByteBuffer();
this.channel.write(byteBuffer, this.position.get(), byteBuffer, this);
}
@Override
protected void hookOnError(Throwable throwable) {
this.sink.error(throwable);
this.error.set(throwable);
if (this.dataBuffer.get() == null) {
this.sink.error(throwable);
}
}
@Override
protected void hookOnComplete() {
this.completed.set(true);
if (this.dataBuffer == null) {
if (this.dataBuffer.get() == null) {
this.sink.complete();
}
}
@ -656,11 +685,13 @@ public abstract class DataBufferUtils { @@ -656,11 +685,13 @@ public abstract class DataBufferUtils {
this.channel.write(byteBuffer, pos, byteBuffer, this);
return;
}
if (this.dataBuffer != null) {
this.sink.next(this.dataBuffer);
this.dataBuffer = null;
sinkDataBuffer();
Throwable throwable = this.error.get();
if (throwable != null) {
this.sink.error(throwable);
}
if (this.completed.get()) {
else if (this.completed.get()) {
this.sink.complete();
}
else {
@ -670,8 +701,16 @@ public abstract class DataBufferUtils { @@ -670,8 +701,16 @@ public abstract class DataBufferUtils {
@Override
public void failed(Throwable exc, ByteBuffer byteBuffer) {
sinkDataBuffer();
this.sink.error(exc);
}
private void sinkDataBuffer() {
DataBuffer dataBuffer = this.dataBuffer.get();
Assert.state(dataBuffer != null, "DataBuffer should not be null");
this.sink.next(dataBuffer);
this.dataBuffer.set(null);
}
}

206
spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java

@ -16,10 +16,12 @@ @@ -16,10 +16,12 @@
package org.springframework.core.io.buffer;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousFileChannel;
import java.nio.channels.CompletionHandler;
import java.nio.channels.FileChannel;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
@ -29,7 +31,7 @@ import java.nio.file.Path; @@ -29,7 +31,7 @@ import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.time.Duration;
import java.util.stream.Collectors;
import java.util.concurrent.CountDownLatch;
import io.netty.buffer.ByteBuf;
import org.junit.Test;
@ -160,9 +162,7 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { @@ -160,9 +162,7 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase {
.expectComplete()
.verify(Duration.ofSeconds(5));
String result = Files.readAllLines(tempFile)
.stream()
.collect(Collectors.joining());
String result = String.join("", Files.readAllLines(tempFile));
assertEquals("foobarbazqux", result);
os.close();
@ -188,14 +188,60 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { @@ -188,14 +188,60 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase {
.expectComplete()
.verify(Duration.ofSeconds(5));
String result = Files.readAllLines(tempFile)
.stream()
.collect(Collectors.joining());
String result = String.join("", Files.readAllLines(tempFile));
assertEquals("foobarbazqux", result);
channel.close();
}
@Test
public void writeWritableByteChannelErrorInFlux() throws Exception {
DataBuffer foo = stringBuffer("foo");
DataBuffer bar = stringBuffer("bar");
Flux<DataBuffer> flux = Flux.just(foo, bar).concatWith(Flux.error(new RuntimeException()));
Path tempFile = Files.createTempFile("DataBufferUtilsTests", null);
WritableByteChannel channel = Files.newByteChannel(tempFile, StandardOpenOption.WRITE);
Flux<DataBuffer> writeResult = DataBufferUtils.write(flux, channel);
StepVerifier.create(writeResult)
.consumeNextWith(stringConsumer("foo"))
.consumeNextWith(stringConsumer("bar"))
.expectError()
.verify(Duration.ofSeconds(5));
String result = String.join("", Files.readAllLines(tempFile));
assertEquals("foobar", result);
channel.close();
}
@Test
public void writeWritableByteChannelErrorInWrite() throws Exception {
DataBuffer foo = stringBuffer("foo");
DataBuffer bar = stringBuffer("bar");
Flux<DataBuffer> flux = Flux.just(foo, bar);
WritableByteChannel channel = mock(WritableByteChannel.class);
when(channel.write(any()))
.thenAnswer(invocation -> {
ByteBuffer buffer = invocation.getArgument(0);
int written = buffer.remaining();
buffer.position(buffer.limit());
return written;
})
.thenThrow(new IOException());
Flux<DataBuffer> writeResult = DataBufferUtils.write(flux, channel);
StepVerifier.create(writeResult)
.consumeNextWith(stringConsumer("foo"))
.consumeNextWith(stringConsumer("bar"))
.expectError(IOException.class)
.verify();
channel.close();
}
@Test
public void writeAsynchronousFileChannel() throws Exception {
DataBuffer foo = stringBuffer("foo");
@ -208,7 +254,7 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { @@ -208,7 +254,7 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase {
AsynchronousFileChannel channel =
AsynchronousFileChannel.open(tempFile, StandardOpenOption.WRITE);
Flux<DataBuffer> writeResult = DataBufferUtils.write(flux, channel, 0);
Flux<DataBuffer> writeResult = DataBufferUtils.write(flux, channel);
StepVerifier.create(writeResult)
.consumeNextWith(stringConsumer("foo"))
.consumeNextWith(stringConsumer("bar"))
@ -217,14 +263,142 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { @@ -217,14 +263,142 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase {
.expectComplete()
.verify(Duration.ofSeconds(5));
String result = Files.readAllLines(tempFile)
.stream()
.collect(Collectors.joining());
String result = String.join("", Files.readAllLines(tempFile));
assertEquals("foobarbazqux", result);
channel.close();
}
@Test
public void writeAsynchronousFileChannelErrorInFlux() throws Exception {
DataBuffer foo = stringBuffer("foo");
DataBuffer bar = stringBuffer("bar");
Flux<DataBuffer> flux =
Flux.just(foo, bar).concatWith(Mono.error(new RuntimeException()));
Path tempFile = Files.createTempFile("DataBufferUtilsTests", null);
AsynchronousFileChannel channel =
AsynchronousFileChannel.open(tempFile, StandardOpenOption.WRITE);
Flux<DataBuffer> writeResult = DataBufferUtils.write(flux, channel);
StepVerifier.create(writeResult)
.consumeNextWith(stringConsumer("foo"))
.consumeNextWith(stringConsumer("bar"))
.expectError(RuntimeException.class)
.verify();
String result = String.join("", Files.readAllLines(tempFile));
assertEquals("foobar", result);
channel.close();
}
@Test
public void writeAsynchronousFileChannelErrorInWrite() throws Exception {
DataBuffer foo = stringBuffer("foo");
DataBuffer bar = stringBuffer("bar");
Flux<DataBuffer> flux = Flux.just(foo, bar);
AsynchronousFileChannel channel = mock(AsynchronousFileChannel.class);
doAnswer(invocation -> {
ByteBuffer buffer = invocation.getArgument(0);
long pos = invocation.getArgument(1);
CompletionHandler<Integer, ByteBuffer> completionHandler = invocation.getArgument(3);
assertEquals(0, pos);
int written = buffer.remaining();
buffer.position(buffer.limit());
completionHandler.completed(written, buffer);
return null;
})
.doAnswer(invocation -> {
ByteBuffer buffer = invocation.getArgument(0);
CompletionHandler<Integer, ByteBuffer> completionHandler =
invocation.getArgument(3);
completionHandler.failed(new IOException(), buffer);
return null;
})
.when(channel).write(isA(ByteBuffer.class), anyLong(), isA(ByteBuffer.class),
isA(CompletionHandler.class));
Flux<DataBuffer> writeResult = DataBufferUtils.write(flux, channel);
StepVerifier.create(writeResult)
.consumeNextWith(stringConsumer("foo"))
.consumeNextWith(stringConsumer("bar"))
.expectError(IOException.class)
.verify();
channel.close();
}
@Test
public void readAndWriteByteChannel() throws Exception {
Path source = Paths.get(
DataBufferUtilsTests.class.getResource("DataBufferUtilsTests.txt").toURI());
Flux<DataBuffer> sourceFlux =
DataBufferUtils
.readByteChannel(() -> FileChannel.open(source, StandardOpenOption.READ),
this.bufferFactory, 3);
Path destination = Files.createTempFile("DataBufferUtilsTests", null);
WritableByteChannel channel = Files.newByteChannel(destination, StandardOpenOption.WRITE);
DataBufferUtils.write(sourceFlux, channel)
.subscribe(DataBufferUtils.releaseConsumer(),
throwable -> fail(throwable.getMessage()),
() -> {
try {
String expected = String.join("", Files.readAllLines(source));
String result = String.join("", Files.readAllLines(destination));
assertEquals(expected, result);
channel.close();
}
catch (IOException e) {
fail(e.getMessage());
}
});
}
@Test
public void readAndWriteAsynchronousFileChannel() throws Exception {
Path source = Paths.get(
DataBufferUtilsTests.class.getResource("DataBufferUtilsTests.txt").toURI());
Flux<DataBuffer> sourceFlux = DataBufferUtils.readAsynchronousFileChannel(
() -> AsynchronousFileChannel.open(source, StandardOpenOption.READ),
this.bufferFactory, 3);
Path destination = Files.createTempFile("DataBufferUtilsTests", null);
AsynchronousFileChannel channel =
AsynchronousFileChannel.open(destination, StandardOpenOption.WRITE);
CountDownLatch latch = new CountDownLatch(1);
DataBufferUtils.write(sourceFlux, channel)
.subscribe(DataBufferUtils::release,
throwable -> fail(throwable.getMessage()),
() -> {
try {
String expected = String.join("", Files.readAllLines(source));
String result = String.join("", Files.readAllLines(destination));
assertEquals(expected, result);
channel.close();
latch.countDown();
}
catch (IOException e) {
fail(e.getMessage());
}
});
latch.await();
}
@Test
public void takeUntilByteCount() {
@ -314,7 +488,8 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { @@ -314,7 +488,8 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase {
.thenAnswer(putByte('c'))
.thenReturn(-1);
Flux<DataBuffer> read = DataBufferUtils.readByteChannel(() -> channel, this.bufferFactory, 1);
Flux<DataBuffer> read =
DataBufferUtils.readByteChannel(() -> channel, this.bufferFactory, 1);
StepVerifier.create(read)
.consumeNextWith(stringConsumer("a"))
@ -343,9 +518,10 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { @@ -343,9 +518,10 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase {
StepVerifier.create(result)
.consumeNextWith(dataBuffer -> {
assertEquals("foobarbaz", DataBufferTestUtils.dumpString(dataBuffer, StandardCharsets.UTF_8));
release(dataBuffer);
})
assertEquals("foobarbaz",
DataBufferTestUtils.dumpString(dataBuffer, StandardCharsets.UTF_8));
release(dataBuffer);
})
.verifyComplete();
}

Loading…
Cancel
Save