From 196c0adf47868467ec9242daea81ce7e0246a152 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Mon, 3 Sep 2018 10:10:14 +0200 Subject: [PATCH] Fixed DataBufferUtils.join leak for error in source This commit fixes an issue where DataBufferUtils.join() would not release databuffers that preceded an error signal. Issue: SPR-17025 --- .../core/io/buffer/DataBufferUtils.java | 169 +++++++++++++++++- .../core/io/buffer/DataBufferUtilsTests.java | 21 ++- 2 files changed, 185 insertions(+), 5 deletions(-) diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java index 74544ffb73..c923504964 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DataBufferUtils.java @@ -32,6 +32,7 @@ import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; +import java.util.function.IntPredicate; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; @@ -473,6 +474,9 @@ public abstract class DataBufferUtils { * Depending on the {@link DataBuffer} implementation, the returned buffer may be a single * buffer containing all data of the provided buffers, or it may be a true composite that * contains references to the buffers. + *

If {@code dataBuffers} contains an error signal, then all buffers that preceded the error + * will be {@linkplain #release(DataBuffer) released}, and the error is stored in the + * returned {@code Mono}. * @param dataBuffers the data buffers that are to be composed * @return a buffer that is composed from the {@code dataBuffers} argument * @since 5.0.3 @@ -481,14 +485,26 @@ public abstract class DataBufferUtils { Assert.notNull(dataBuffers, "'dataBuffers' must not be null"); return Flux.from(dataBuffers) + .onErrorResume(DataBufferUtils::exceptionDataBuffer) .collectList() .filter(list -> !list.isEmpty()) - .map(list -> { + .flatMap(list -> { + for (int i = 0; i < list.size(); i++) { + DataBuffer dataBuffer = list.get(i); + if (dataBuffer instanceof ExceptionDataBuffer) { + list.subList(0, i).forEach(DataBufferUtils::release); + return Mono.error(((ExceptionDataBuffer) dataBuffer).throwable()); + } + } DataBufferFactory bufferFactory = list.get(0).factory(); - return bufferFactory.join(list); + return Mono.just(bufferFactory.join(list)); }); } + private static Mono exceptionDataBuffer(Throwable throwable) { + return Mono.just(new ExceptionDataBuffer(throwable)); + } + private static class ReadableByteChannelGenerator implements Consumer> { @@ -658,4 +674,153 @@ public abstract class DataBufferUtils { } } + /** + * DataBuffer implementation that holds a {@link Throwable}, used in {@link #join(Publisher)}. + */ + private static final class ExceptionDataBuffer implements DataBuffer { + + private final Throwable throwable; + + + public ExceptionDataBuffer(Throwable throwable) { + this.throwable = throwable; + } + + public Throwable throwable() { + return this.throwable; + } + + // Unsupported + + @Override + public DataBufferFactory factory() { + throw new UnsupportedOperationException(); + } + + @Override + public int indexOf(IntPredicate predicate, int fromIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public int lastIndexOf(IntPredicate predicate, int fromIndex) { + throw new UnsupportedOperationException(); + } + + @Override + public int readableByteCount() { + throw new UnsupportedOperationException(); + } + + @Override + public int writableByteCount() { + throw new UnsupportedOperationException(); + } + + @Override + public int capacity() { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer capacity(int capacity) { + throw new UnsupportedOperationException(); + } + + @Override + public int readPosition() { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer readPosition(int readPosition) { + throw new UnsupportedOperationException(); + } + + @Override + public int writePosition() { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer writePosition(int writePosition) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int index) { + throw new UnsupportedOperationException(); + } + + @Override + public byte read() { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer read(byte[] destination) { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer read(byte[] destination, int offset, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer write(byte b) { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer write(byte[] source) { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer write(byte[] source, int offset, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer write(DataBuffer... buffers) { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer write(ByteBuffer... buffers) { + throw new UnsupportedOperationException(); + } + + @Override + public DataBuffer slice(int index, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuffer asByteBuffer() { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuffer asByteBuffer(int index, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream asInputStream() { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream asInputStream(boolean releaseOnClose) { + throw new UnsupportedOperationException(); + } + + @Override + public OutputStream asOutputStream() { + throw new UnsupportedOperationException(); + } + } + } diff --git a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java index 4b7ddb7dc3..f8419d40cc 100644 --- a/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java +++ b/spring-core/src/test/java/org/springframework/core/io/buffer/DataBufferUtilsTests.java @@ -35,6 +35,7 @@ import io.netty.buffer.ByteBuf; import org.junit.Test; import org.mockito.stubbing.Answer; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import org.springframework.core.io.ClassPathResource; @@ -338,12 +339,26 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { DataBuffer bar = stringBuffer("bar"); DataBuffer baz = stringBuffer("baz"); Flux flux = Flux.just(foo, bar, baz); + Mono result = DataBufferUtils.join(flux); - DataBuffer result = DataBufferUtils.join(flux).block(Duration.ofSeconds(5)); + StepVerifier.create(result) + .consumeNextWith(dataBuffer -> { + assertEquals("foobarbaz", DataBufferTestUtils.dumpString(dataBuffer, StandardCharsets.UTF_8)); + release(dataBuffer); + }) + .verifyComplete(); + } - assertEquals("foobarbaz", DataBufferTestUtils.dumpString(result, StandardCharsets.UTF_8)); + @Test + public void joinErrors() { + DataBuffer foo = stringBuffer("foo"); + DataBuffer bar = stringBuffer("bar"); + Flux flux = Flux.just(foo, bar).mergeWith(Flux.error(new RuntimeException())); + Mono result = DataBufferUtils.join(flux); - release(result); + StepVerifier.create(result) + .expectError(RuntimeException.class) + .verify(); } }