diff --git a/spring-core/src/main/java/org/springframework/core/io/buffer/DefaultDataBuffer.java b/spring-core/src/main/java/org/springframework/core/io/buffer/DefaultDataBuffer.java index 0bf1d2a35c..75b83a6198 100644 --- a/spring-core/src/main/java/org/springframework/core/io/buffer/DefaultDataBuffer.java +++ b/spring-core/src/main/java/org/springframework/core/io/buffer/DefaultDataBuffer.java @@ -80,10 +80,14 @@ public class DefaultDataBuffer implements DataBuffer { /** - * Directly exposes the native {@code ByteBuffer} that this buffer is based on. + * Directly exposes the native {@code ByteBuffer} that this buffer is based + * on also updating the {@code ByteBuffer's} position and limit to match + * the current {@link #readPosition()} and {@link #readableByteCount()}. * @return the wrapped byte buffer */ public ByteBuffer getNativeBuffer() { + this.byteBuffer.position(this.readPosition); + this.byteBuffer.limit(readableByteCount()); return this.byteBuffer; } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractEncoderMethodReturnValueHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractEncoderMethodReturnValueHandler.java index 8fc2e59879..983e1448f6 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractEncoderMethodReturnValueHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractEncoderMethodReturnValueHandler.java @@ -33,7 +33,6 @@ import org.springframework.core.ResolvableType; import org.springframework.core.codec.Encoder; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; -import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.lang.Nullable; import org.springframework.messaging.Message; @@ -148,7 +147,7 @@ public abstract class AbstractEncoderMethodReturnValueHandler implements Handler Encoder encoder = getEncoder(elementType, mimeType); - return Flux.from((Publisher) publisher).concatMap(value -> + return Flux.from((Publisher) publisher).map(value -> encodeValue(value, elementType, encoder, bufferFactory, mimeType, hints)); } @@ -176,7 +175,7 @@ public abstract class AbstractEncoderMethodReturnValueHandler implements Handler } @SuppressWarnings("unchecked") - private Mono encodeValue( + private DataBuffer encodeValue( Object element, ResolvableType elementType, @Nullable Encoder encoder, DataBufferFactory bufferFactory, @Nullable MimeType mimeType, @Nullable Map hints) { @@ -184,13 +183,11 @@ public abstract class AbstractEncoderMethodReturnValueHandler implements Handler if (encoder == null) { encoder = getEncoder(ResolvableType.forInstance(element), mimeType); if (encoder == null) { - return Mono.error(new MessagingException( - "No encoder for " + elementType + ", current value type is " + element.getClass())); + throw new MessagingException( + "No encoder for " + elementType + ", current value type is " + element.getClass()); } } - Mono mono = Mono.just((T) element); - Flux dataBuffers = encoder.encode(mono, bufferFactory, elementType, mimeType, hints); - return DataBufferUtils.join(dataBuffers); + return encoder.encodeValue((T) element, bufferFactory, elementType, mimeType, hints); } /** diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java index c112cb16bd..97e0b2351e 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequester.java @@ -32,7 +32,6 @@ import org.springframework.core.ResolvableType; import org.springframework.core.codec.Decoder; import org.springframework.core.codec.Encoder; import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.MimeType; @@ -124,8 +123,10 @@ final class DefaultRSocketRequester implements RSocketRequester { publisher = adapter.toPublisher(input); } else { - Mono payloadMono = encodeValue(input, ResolvableType.forInstance(input), null) + Mono payloadMono = Mono + .fromCallable(() -> encodeValue(input, ResolvableType.forInstance(input), null)) .map(this::firstPayload) + .doOnDiscard(Payload.class, Payload::release) .switchIfEmpty(emptyPayload()); return new DefaultResponseSpec(payloadMono); } @@ -140,36 +141,36 @@ final class DefaultRSocketRequester implements RSocketRequester { if (adapter != null && !adapter.isMultiValue()) { Mono payloadMono = Mono.from(publisher) - .flatMap(value -> encodeValue(value, dataType, encoder)) + .map(value -> encodeValue(value, dataType, encoder)) .map(this::firstPayload) .switchIfEmpty(emptyPayload()); return new DefaultResponseSpec(payloadMono); } Flux payloadFlux = Flux.from(publisher) - .concatMap(value -> encodeValue(value, dataType, encoder)) + .map(value -> encodeValue(value, dataType, encoder)) .switchOnFirst((signal, inner) -> { DataBuffer data = signal.get(); if (data != null) { - return Flux.concat( - Mono.just(firstPayload(data)), - inner.skip(1).map(PayloadUtils::createPayload)); + return Mono.fromCallable(() -> firstPayload(data)) + .concatWith(inner.skip(1).map(PayloadUtils::createPayload)); } else { return inner.map(PayloadUtils::createPayload); } }) + .doOnDiscard(Payload.class, Payload::release) .switchIfEmpty(emptyPayload()); return new DefaultResponseSpec(payloadFlux); } @SuppressWarnings("unchecked") - private Mono encodeValue(T value, ResolvableType valueType, @Nullable Encoder encoder) { + private DataBuffer encodeValue(T value, ResolvableType valueType, @Nullable Encoder encoder) { if (encoder == null) { encoder = strategies.encoder(ResolvableType.forInstance(value), dataMimeType); } - return DataBufferUtils.join(((Encoder) encoder).encode( - Mono.just(value), strategies.dataBufferFactory(), valueType, dataMimeType, EMPTY_HINTS)); + return ((Encoder) encoder).encodeValue( + value, strategies.dataBufferFactory(), valueType, dataMimeType, EMPTY_HINTS); } private Payload firstPayload(DataBuffer data) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/handler/annotation/support/reactive/MessageMappingMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/handler/annotation/support/reactive/MessageMappingMessageHandlerTests.java index 5496e87eeb..837311c403 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/handler/annotation/support/reactive/MessageMappingMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/handler/annotation/support/reactive/MessageMappingMessageHandlerTests.java @@ -81,7 +81,7 @@ public class MessageMappingMessageHandlerTests { @Test public void handleFluxString() { MessageMappingMessageHandler messsageHandler = initMesssageHandler(); - messsageHandler.handleMessage(message("fluxString", "abc\ndef\nghi")).block(Duration.ofSeconds(5)); + messsageHandler.handleMessage(message("fluxString", "abc", "def", "ghi")).block(Duration.ofSeconds(5)); verifyOutputContent(Arrays.asList("abc::response", "def::response", "ghi::response")); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/handler/annotation/support/reactive/PayloadMethodArgumentResolverTests.java b/spring-messaging/src/test/java/org/springframework/messaging/handler/annotation/support/reactive/PayloadMethodArgumentResolverTests.java index 83e1293cd8..da82e20dea 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/handler/annotation/support/reactive/PayloadMethodArgumentResolverTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/handler/annotation/support/reactive/PayloadMethodArgumentResolverTests.java @@ -129,9 +129,10 @@ public class PayloadMethodArgumentResolverTests { @Test public void validateStringMono() { + TestValidator validator = new TestValidator(); ResolvableType type = ResolvableType.forClassWithGenerics(Mono.class, String.class); MethodParameter param = this.testMethod.arg(type); - Mono mono = resolveValue(param, Mono.just(toDataBuffer("12345")), new TestValidator()); + Mono mono = resolveValue(param, Mono.just(toDataBuffer("12345")), validator); StepVerifier.create(mono).expectNextCount(0) .expectError(MethodArgumentNotValidException.class).verify(); @@ -139,9 +140,11 @@ public class PayloadMethodArgumentResolverTests { @Test public void validateStringFlux() { + TestValidator validator = new TestValidator(); ResolvableType type = ResolvableType.forClassWithGenerics(Flux.class, String.class); MethodParameter param = this.testMethod.arg(type); - Flux flux = resolveValue(param, Mono.just(toDataBuffer("12345678\n12345")), new TestValidator()); + Flux content = Flux.just(toDataBuffer("12345678"), toDataBuffer("12345")); + Flux flux = resolveValue(param, content, validator); StepVerifier.create(flux) .expectNext("12345678") diff --git a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java index fec0ed09ec..3840bae308 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java +++ b/spring-web/src/main/java/org/springframework/http/codec/ServerSentEventHttpMessageWriter.java @@ -18,6 +18,7 @@ package org.springframework.http.codec; import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -111,9 +112,9 @@ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter> encode(Publisher input, ResolvableType elementType, - MediaType mediaType, DataBufferFactory factory, Map hints) { + MediaType mediaType, DataBufferFactory bufferFactory, Map hints) { - ResolvableType valueType = (ServerSentEvent.class.isAssignableFrom(elementType.toClass()) ? + ResolvableType dataType = (ServerSentEvent.class.isAssignableFrom(elementType.toClass()) ? elementType.getGeneric() : elementType); return Flux.from(input).map(element -> { @@ -143,12 +144,10 @@ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter flux = Flux.concat( - encodeText(sb, mediaType, factory), - encodeData(data, valueType, mediaType, factory, hints), - encodeText("\n", mediaType, factory)); + Mono bufferMono = Mono.fromCallable(() -> + bufferFactory.join(encodeEvent(sb, data, dataType, mediaType, bufferFactory, hints))); - return flux.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); + return bufferMono.doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release); }); } @@ -160,31 +159,32 @@ public class ServerSentEventHttpMessageWriter implements HttpMessageWriter Flux encodeData(@Nullable T dataValue, ResolvableType valueType, + private List encodeEvent(CharSequence markup, @Nullable T data, ResolvableType dataType, MediaType mediaType, DataBufferFactory factory, Map hints) { - if (dataValue == null) { - return Flux.empty(); - } - - if (dataValue instanceof String) { - String text = (String) dataValue; - return Flux.from(encodeText(StringUtils.replace(text, "\n", "\ndata:") + "\n", mediaType, factory)); - } - - if (this.encoder == null) { - return Flux.error(new CodecException("No SSE encoder configured and the data is not String.")); + List result = new ArrayList<>(4); + result.add(encodeText(markup, mediaType, factory)); + if (data != null) { + if (data instanceof String) { + String dataLine = StringUtils.replace((String) data, "\n", "\ndata:") + "\n"; + result.add(encodeText(dataLine, mediaType, factory)); + } + else if (this.encoder == null) { + throw new CodecException("No SSE encoder configured and the data is not String."); + } + else { + result.add(((Encoder) this.encoder).encodeValue(data, factory, dataType, mediaType, hints)); + result.add(encodeText("\n", mediaType, factory)); + } } - - return ((Encoder) this.encoder) - .encode(Mono.just(dataValue), factory, valueType, mediaType, hints) - .concatWith(encodeText("\n", mediaType, factory)); + result.add(encodeText("\n", mediaType, factory)); + return result; } - private Mono encodeText(CharSequence text, MediaType mediaType, DataBufferFactory bufferFactory) { + private DataBuffer encodeText(CharSequence text, MediaType mediaType, DataBufferFactory bufferFactory) { Assert.notNull(mediaType.getCharset(), "Expected MediaType with charset"); byte[] bytes = text.toString().getBytes(mediaType.getCharset()); - return Mono.just(bufferFactory.wrap(bytes)); // wrapping, not allocating + return bufferFactory.wrap(bytes); // wrapping, not allocating } @Override diff --git a/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageWriterTests.java index 10eb6bffee..6949efb9e9 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageWriterTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/ServerSentEventHttpMessageWriterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,7 +40,7 @@ import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; import static org.junit.Assert.*; -import static org.springframework.core.ResolvableType.forClass; +import static org.springframework.core.ResolvableType.*; /** * Unit tests for {@link ServerSentEventHttpMessageWriter}. @@ -88,9 +88,8 @@ public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAll testWrite(source, outputMessage, ServerSentEvent.class); StepVerifier.create(outputMessage.getBody()) - .consumeNextWith(stringConsumer("id:c42\nevent:foo\nretry:123\n:bla\n:bla bla\n:bla bla bla\ndata:")) - .consumeNextWith(stringConsumer("bar\n")) - .consumeNextWith(stringConsumer("\n")) + .consumeNextWith(stringConsumer( + "id:c42\nevent:foo\nretry:123\n:bla\n:bla bla\n:bla bla bla\ndata:bar\n\n")) .expectComplete() .verify(); } @@ -101,12 +100,8 @@ public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAll testWrite(source, outputMessage, String.class); StepVerifier.create(outputMessage.getBody()) - .consumeNextWith(stringConsumer("data:")) - .consumeNextWith(stringConsumer("foo\n")) - .consumeNextWith(stringConsumer("\n")) - .consumeNextWith(stringConsumer("data:")) - .consumeNextWith(stringConsumer("bar\n")) - .consumeNextWith(stringConsumer("\n")) + .consumeNextWith(stringConsumer("data:foo\n\n")) + .consumeNextWith(stringConsumer("data:bar\n\n")) .expectComplete() .verify(); } @@ -117,12 +112,8 @@ public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAll testWrite(source, outputMessage, String.class); StepVerifier.create(outputMessage.getBody()) - .consumeNextWith(stringConsumer("data:")) - .consumeNextWith(stringConsumer("foo\ndata:bar\n")) - .consumeNextWith(stringConsumer("\n")) - .consumeNextWith(stringConsumer("data:")) - .consumeNextWith(stringConsumer("foo\ndata:baz\n")) - .consumeNextWith(stringConsumer("\n")) + .consumeNextWith(stringConsumer("data:foo\ndata:bar\n\n")) + .consumeNextWith(stringConsumer("data:foo\ndata:baz\n\n")) .expectComplete() .verify(); } @@ -136,14 +127,11 @@ public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAll assertEquals(mediaType, outputMessage.getHeaders().getContentType()); StepVerifier.create(outputMessage.getBody()) - .consumeNextWith(stringConsumer("data:")) .consumeNextWith(dataBuffer -> { - String value = - DataBufferTestUtils.dumpString(dataBuffer, charset); + String value = DataBufferTestUtils.dumpString(dataBuffer, charset); DataBufferUtils.release(dataBuffer); - assertEquals("\u00A3\n", value); + assertEquals("data:\u00A3\n\n", value); }) - .consumeNextWith(stringConsumer("\n")) .expectComplete() .verify(); } @@ -154,14 +142,8 @@ public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAll testWrite(source, outputMessage, Pojo.class); StepVerifier.create(outputMessage.getBody()) - .consumeNextWith(stringConsumer("data:")) - .consumeNextWith(stringConsumer("{\"foo\":\"foofoo\",\"bar\":\"barbar\"}")) - .consumeNextWith(stringConsumer("\n")) - .consumeNextWith(stringConsumer("\n")) - .consumeNextWith(stringConsumer("data:")) - .consumeNextWith(stringConsumer("{\"foo\":\"foofoofoo\",\"bar\":\"barbarbar\"}")) - .consumeNextWith(stringConsumer("\n")) - .consumeNextWith(stringConsumer("\n")) + .consumeNextWith(stringConsumer("data:{\"foo\":\"foofoo\",\"bar\":\"barbar\"}\n\n")) + .consumeNextWith(stringConsumer("data:{\"foo\":\"foofoofoo\",\"bar\":\"barbarbar\"}\n\n")) .expectComplete() .verify(); } @@ -175,18 +157,12 @@ public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAll testWrite(source, outputMessage, Pojo.class); StepVerifier.create(outputMessage.getBody()) - .consumeNextWith(stringConsumer("data:")) - .consumeNextWith(stringConsumer("{\n" + + .consumeNextWith(stringConsumer("data:{\n" + "data: \"foo\" : \"foofoo\",\n" + - "data: \"bar\" : \"barbar\"\n" + "data:}")) - .consumeNextWith(stringConsumer("\n")) - .consumeNextWith(stringConsumer("\n")) - .consumeNextWith(stringConsumer("data:")) - .consumeNextWith(stringConsumer("{\n" + + "data: \"bar\" : \"barbar\"\n" + "data:}\n\n")) + .consumeNextWith(stringConsumer("data:{\n" + "data: \"foo\" : \"foofoofoo\",\n" + - "data: \"bar\" : \"barbarbar\"\n" + "data:}")) - .consumeNextWith(stringConsumer("\n")) - .consumeNextWith(stringConsumer("\n")) + "data: \"bar\" : \"barbarbar\"\n" + "data:}\n\n")) .expectComplete() .verify(); } @@ -200,28 +176,10 @@ public class ServerSentEventHttpMessageWriterTests extends AbstractDataBufferAll assertEquals(mediaType, outputMessage.getHeaders().getContentType()); StepVerifier.create(outputMessage.getBody()) - .consumeNextWith(dataBuffer1 -> { - String value1 = - DataBufferTestUtils.dumpString(dataBuffer1, charset); - DataBufferUtils.release(dataBuffer1); - assertEquals("data:", value1); - }) .consumeNextWith(dataBuffer -> { String value = DataBufferTestUtils.dumpString(dataBuffer, charset); DataBufferUtils.release(dataBuffer); - assertEquals("{\"foo\":\"foo\uD834\uDD1E\",\"bar\":\"bar\uD834\uDD1E\"}", value); - }) - .consumeNextWith(dataBuffer2 -> { - String value2 = - DataBufferTestUtils.dumpString(dataBuffer2, charset); - DataBufferUtils.release(dataBuffer2); - assertEquals("\n", value2); - }) - .consumeNextWith(dataBuffer3 -> { - String value3 = - DataBufferTestUtils.dumpString(dataBuffer3, charset); - DataBufferUtils.release(dataBuffer3); - assertEquals("\n", value3); + assertEquals("data:{\"foo\":\"foo\uD834\uDD1E\",\"bar\":\"bar\uD834\uDD1E\"}\n\n", value); }) .expectComplete() .verify();