diff --git a/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/DataBuffer.java b/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/DataBuffer.java index cf4770adee..31cf7f63ac 100644 --- a/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/DataBuffer.java +++ b/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/DataBuffer.java @@ -19,6 +19,7 @@ package org.springframework.core.io.buffer; import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; +import java.util.function.IntPredicate; /** * Basic abstraction over byte buffers. @@ -34,12 +35,22 @@ public interface DataBuffer { DataBufferAllocator allocator(); /** - * Gets the byte at the specified index. - * @param index the index - * @return the byte at the specified index - * @throws IndexOutOfBoundsException if the given index is out of bounds + * Returns the index of the first byte in this buffer that matches the given + * predicate. + * @param predicate the predicate to match + * @return the index of the first byte that matches {@code predicate}; or {@code -1} + * if none match */ - byte get(int index); + int indexOf(IntPredicate predicate); + + /** + * Returns the index of the last byte in this buffer that matches the given + * predicate. + * @param predicate the predicate to match + * @return the index of the last byte that matches {@code predicate}; or {@code -1} + * if none match + */ + int lastIndexOf(IntPredicate predicate); /** * Returns the number of bytes that can be read from this data buffer. @@ -113,11 +124,22 @@ public interface DataBuffer { */ DataBuffer write(ByteBuffer... buffers); + /** + * Creates a new {@code DataBuffer} whose contents is a shared subsequence of this + * data buffer's content. Data between this data buffer and the returned buffer is + * shared; though changes in the returned buffer's position will not be reflected + * in the reading nor writing position of this data buffer. + * @param index the index at which to start the slice + * @param length the length of the slice + * @return the specified slice of this data buffer + */ + DataBuffer slice(int index, int length); + /** * Exposes this buffer's bytes as a {@link ByteBuffer}. Data between this {@code * DataBuffer} and the returned {@code ByteBuffer} is shared; though changes in the * returned buffer's {@linkplain ByteBuffer#position() position} will not be reflected - * in the position(s) of this data buffer. + * in the reading nor writing position of this data buffer. * @return this data buffer as a byte buffer */ ByteBuffer asByteBuffer(); diff --git a/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/DefaultDataBuffer.java b/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/DefaultDataBuffer.java index 823f644eec..940c0d1a8b 100644 --- a/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/DefaultDataBuffer.java +++ b/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/DefaultDataBuffer.java @@ -22,6 +22,7 @@ import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.function.Function; +import java.util.function.IntPredicate; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; @@ -82,8 +83,25 @@ public class DefaultDataBuffer implements DataBuffer { } @Override - public byte get(int index) { - return this.byteBuffer.get(index); + public int indexOf(IntPredicate predicate) { + for (int i = 0; i < readableByteCount(); i++) { + byte b = this.byteBuffer.get(i); + if (predicate.test(b)) { + return i; + } + } + return -1; + } + + @Override + public int lastIndexOf(IntPredicate predicate) { + for (int i = readableByteCount() - 1; i >= 0; i--) { + byte b = this.byteBuffer.get(i); + if (predicate.test(b)) { + return i; + } + } + return -1; } @Override @@ -120,14 +138,16 @@ public class DefaultDataBuffer implements DataBuffer { */ private T readInternal(Function function) { this.byteBuffer.position(this.readPosition); - T result = function.apply(this.byteBuffer); - this.readPosition = this.byteBuffer.position(); - return result; + try { + return function.apply(this.byteBuffer); + } + finally { + this.readPosition = this.byteBuffer.position(); + } } @Override public DefaultDataBuffer write(byte b) { - ensureExtraCapacity(1); writeInternal(buffer -> buffer.put(b)); @@ -184,9 +204,27 @@ public class DefaultDataBuffer implements DataBuffer { */ private T writeInternal(Function function) { this.byteBuffer.position(this.writePosition); - T result = function.apply(this.byteBuffer); - this.writePosition = this.byteBuffer.position(); - return result; + try { + return function.apply(this.byteBuffer); + } + finally { + this.writePosition = this.byteBuffer.position(); + } + } + + @Override + public DataBuffer slice(int index, int length) { + int oldPosition = this.byteBuffer.position(); + try { + this.byteBuffer.position(index); + ByteBuffer slice = this.byteBuffer.slice(); + slice.limit(length); + return new SlicedDefaultDataBuffer(slice, 0, length, this.allocator); + } + finally { + this.byteBuffer.position(oldPosition); + } + } @Override @@ -214,7 +252,7 @@ public class DefaultDataBuffer implements DataBuffer { } } - private void grow(int minCapacity) { + void grow(int minCapacity) { ByteBuffer oldBuffer = this.byteBuffer; ByteBuffer newBuffer = (oldBuffer.isDirect() ? ByteBuffer.allocateDirect(minCapacity) : @@ -227,6 +265,7 @@ public class DefaultDataBuffer implements DataBuffer { oldBuffer.clear(); } + @Override public int hashCode() { return this.byteBuffer.hashCode(); @@ -294,4 +333,18 @@ public class DefaultDataBuffer implements DataBuffer { writeInternal(buffer -> buffer.put(bytes, off, len)); } } + + private static class SlicedDefaultDataBuffer extends DefaultDataBuffer { + + SlicedDefaultDataBuffer(ByteBuffer byteBuffer, int readPosition, + int writePosition, DefaultDataBufferAllocator allocator) { + super(byteBuffer, readPosition, writePosition, allocator); + } + + @Override + void grow(int minCapacity) { + throw new UnsupportedOperationException( + "Growing the capacity of a sliced buffer is not supported"); + } + } } diff --git a/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/NettyDataBuffer.java b/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/NettyDataBuffer.java index 7078142d6e..c5aa526c04 100644 --- a/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/NettyDataBuffer.java +++ b/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/NettyDataBuffer.java @@ -20,6 +20,7 @@ import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.function.IntPredicate; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufInputStream; @@ -68,8 +69,15 @@ public class NettyDataBuffer implements PooledDataBuffer { } @Override - public byte get(int index) { - return this.byteBuf.getByte(index); + public int indexOf(IntPredicate predicate) { + IntPredicate negated = predicate.negate(); + return this.byteBuf.forEachByte(negated::test); + } + + @Override + public int lastIndexOf(IntPredicate predicate) { + IntPredicate negated = predicate.negate(); + return this.byteBuf.forEachByteDesc(negated::test); } @Override @@ -166,6 +174,12 @@ public class NettyDataBuffer implements PooledDataBuffer { return this; } + @Override + public DataBuffer slice(int index, int length) { + ByteBuf slice = this.byteBuf.slice(index, length); + return new NettyDataBuffer(slice, this.allocator); + } + @Override public ByteBuffer asByteBuffer() { return this.byteBuf.nioBuffer(); @@ -183,8 +197,7 @@ public class NettyDataBuffer implements PooledDataBuffer { @Override public PooledDataBuffer retain() { - this.byteBuf.retain(); - return this; + return new NettyDataBuffer(this.byteBuf.retain(), allocator); } @Override diff --git a/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/support/DataBufferUtils.java b/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/support/DataBufferUtils.java index 85b475e107..5c264a6dad 100644 --- a/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/support/DataBufferUtils.java +++ b/spring-web-reactive/src/main/java/org/springframework/core/io/buffer/support/DataBufferUtils.java @@ -21,15 +21,16 @@ import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.channels.Channels; import java.nio.channels.ReadableByteChannel; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.function.IntPredicate; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; import reactor.core.publisher.Flux; -import reactor.core.publisher.FluxSource; import reactor.core.subscriber.SignalEmitter; import org.springframework.core.io.buffer.DataBuffer; @@ -102,7 +103,79 @@ public abstract class DataBufferUtils { Assert.notNull(publisher, "'publisher' must not be null"); Assert.isTrue(maxByteCount >= 0, "'maxByteCount' must be a positive number"); - return new TakeByteUntilCount(publisher, maxByteCount); + AtomicLong byteCountDown = new AtomicLong(maxByteCount); + + return Flux.from(publisher). + takeWhile(dataBuffer -> { + int delta = -dataBuffer.readableByteCount(); + long currentCount = byteCountDown.getAndAdd(delta); + return currentCount >= 0; + }). + map(dataBuffer -> { + long currentCount = byteCountDown.get(); + if (currentCount >= 0) { + return dataBuffer; + } + else { + // last buffer + int size = (int) (currentCount + dataBuffer.readableByteCount()); + return dataBuffer.slice(0, size); + } + }); + } + + /** + * Tokenize the {@link DataBuffer} using the given delimiter + * function. Does not include the delimiter in the result. + * @param dataBuffer the data buffer to tokenize + * @param delimiter the delimiter function + * @return the tokens + */ + public static List tokenize(DataBuffer dataBuffer, + IntPredicate delimiter) { + Assert.notNull(dataBuffer, "'dataBuffer' must not be null"); + Assert.notNull(delimiter, "'delimiter' must not be null"); + + List results = new ArrayList(); + int idx; + do { + idx = dataBuffer.indexOf(delimiter); + if (idx < 0) { + results.add(dataBuffer); + } + else { + if (idx > 0) { + DataBuffer slice = dataBuffer.slice(0, idx); + slice = retain(slice); + results.add(slice); + } + int remainingLen = dataBuffer.readableByteCount() - (idx + 1); + if (remainingLen > 0) { + dataBuffer = dataBuffer.slice(idx + 1, remainingLen); + } + else { + release(dataBuffer); + idx = -1; + } + } + } + while (idx != -1); + return Collections.unmodifiableList(results); + } + + /** + * Retains the given data buffer, it it is a {@link PooledDataBuffer}. + * @param dataBuffer the data buffer to retain + * @return the retained buffer + */ + @SuppressWarnings("unchecked") + public static T retain(T dataBuffer) { + if (dataBuffer instanceof PooledDataBuffer) { + return (T) ((PooledDataBuffer) dataBuffer).retain(); + } + else { + return dataBuffer; + } } /** @@ -117,63 +190,6 @@ public abstract class DataBufferUtils { return false; } - private static final class TakeByteUntilCount extends FluxSource { - - final long maxByteCount; - - TakeByteUntilCount(Publisher source, long maxByteCount) { - super(source); - this.maxByteCount = maxByteCount; - } - - @Override - public void subscribe(Subscriber subscriber) { - source.subscribe(new Subscriber() { - - private Subscription subscription; - - private final AtomicLong byteCount = new AtomicLong(); - - @Override - public void onSubscribe(Subscription s) { - this.subscription = s; - subscriber.onSubscribe(s); - } - - @Override - public void onNext(DataBuffer dataBuffer) { - int delta = dataBuffer.readableByteCount(); - long currentCount = this.byteCount.addAndGet(delta); - if (currentCount > maxByteCount) { - int size = (int) (maxByteCount - currentCount + delta); - ByteBuffer byteBuffer = - (ByteBuffer) dataBuffer.asByteBuffer().limit(size); - DataBuffer partialBuffer = - dataBuffer.allocator().allocateBuffer(size); - partialBuffer.write(byteBuffer); - - subscriber.onNext(partialBuffer); - subscriber.onComplete(); - this.subscription.cancel(); - } - else { - subscriber.onNext(dataBuffer); - } - } - - @Override - public void onError(Throwable t) { - subscriber.onError(t); - } - - @Override - public void onComplete() { - subscriber.onComplete(); - } - }); - } - } - private static class ReadableByteChannelGenerator implements BiFunction, ReadableByteChannel> { diff --git a/spring-web-reactive/src/test/java/org/springframework/core/io/buffer/DataBufferTests.java b/spring-web-reactive/src/test/java/org/springframework/core/io/buffer/DataBufferTests.java index cf9806f098..4014fd629a 100644 --- a/spring-web-reactive/src/test/java/org/springframework/core/io/buffer/DataBufferTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/core/io/buffer/DataBufferTests.java @@ -24,8 +24,7 @@ import java.util.Arrays; import org.junit.Test; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; +import static org.junit.Assert.*; /** * @author Arjen Poutsma @@ -184,7 +183,65 @@ public class DataBufferTests extends AbstractDataBufferAllocatingTestCase { assertArrayEquals(new byte[]{'b', 'c'}, resultBytes); release(buffer); + } + + @Test + public void indexOf() { + DataBuffer buffer = createDataBuffer(3); + buffer.write(new byte[]{'a', 'b', 'c'}); + + int result = buffer.indexOf(b -> b == 'c'); + assertEquals(2, result); + + result = buffer.indexOf(b -> b == 'z'); + assertEquals(-1, result); + + release(buffer); + } + + @Test + public void lastIndexOf() { + DataBuffer buffer = createDataBuffer(3); + buffer.write(new byte[]{'a', 'b', 'c'}); + + int result = buffer.lastIndexOf(b -> b == 'b'); + assertEquals(1, result); + result = buffer.lastIndexOf(b -> b == 'z'); + assertEquals(-1, result); + + release(buffer); + } + + @Test + public void slice() { + DataBuffer buffer = createDataBuffer(3); + buffer.write(new byte[]{'a', 'b'}); + + DataBuffer slice = buffer.slice(1, 2); + assertEquals(2, slice.readableByteCount()); + try { + slice.write((byte) 0); + fail("IndexOutOfBoundsException expected"); + } + catch (Exception ignored) { + } + buffer.write((byte) 'c'); + + assertEquals(3, buffer.readableByteCount()); + byte[] result = new byte[3]; + buffer.read(result); + + assertArrayEquals(new byte[]{'a', 'b', 'c'}, result); + + assertEquals(2, slice.readableByteCount()); + result = new byte[2]; + slice.read(result); + + assertArrayEquals(new byte[]{'b', 'c'}, result); + + + release(buffer); } diff --git a/spring-web-reactive/src/test/java/org/springframework/core/io/buffer/support/DataBufferUtilsTests.java b/spring-web-reactive/src/test/java/org/springframework/core/io/buffer/support/DataBufferUtilsTests.java index f4971bc5c0..e00aa4125b 100644 --- a/spring-web-reactive/src/test/java/org/springframework/core/io/buffer/support/DataBufferUtilsTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/core/io/buffer/support/DataBufferUtilsTests.java @@ -19,8 +19,10 @@ package org.springframework.core.io.buffer.support; import java.io.InputStream; import java.net.URI; import java.nio.channels.FileChannel; +import java.nio.charset.StandardCharsets; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; +import java.util.List; import org.junit.Test; import reactor.core.publisher.Flux; @@ -29,6 +31,7 @@ import reactor.core.test.TestSubscriber; import org.springframework.core.io.buffer.AbstractDataBufferAllocatingTestCase; import org.springframework.core.io.buffer.DataBuffer; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; /** @@ -103,7 +106,25 @@ public class DataBufferUtilsTests extends AbstractDataBufferAllocatingTestCase { assertComplete(). assertValuesWith(stringConsumer("foo"), stringConsumer("ba")); - release(bar, baz); + release(baz); + } + + @Test + public void tokenize() { + DataBuffer dataBuffer = stringBuffer("-foo--bar-"); + + List results = DataBufferUtils.tokenize(dataBuffer, b -> b == '-'); + assertEquals(2, results.size()); + + DataBuffer result = results.get(0); + String value = DataBufferTestUtils.dumpString(result, StandardCharsets.UTF_8); + assertEquals("foo", value); + + result = results.get(1); + value = DataBufferTestUtils.dumpString(result, StandardCharsets.UTF_8); + assertEquals("bar", value); + + results.stream().forEach(b -> release(b)); }