diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java index bdaea4dff5..1ae547a57d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java @@ -87,7 +87,7 @@ final class DefaultRSocketRequesterBuilder implements RSocketRequester.Builder { this.factoryConfigurers.forEach(configurer -> configurer.accept(factory)); return factory.transport(transport).start() - .map(rsocket -> RSocketRequester.create(rsocket, dataMimeType, strategies)); + .map(rsocket -> new DefaultRSocketRequester(rsocket, dataMimeType, strategies)); }); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessageHandlerAcceptor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessageHandlerAcceptor.java index 600633c949..ea5d02a65c 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessageHandlerAcceptor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessageHandlerAcceptor.java @@ -72,7 +72,10 @@ public final class MessageHandlerAcceptor extends RSocketMessageHandler } private MessagingRSocket createRSocket(RSocket rsocket) { - return new MessagingRSocket(this::handleMessage, rsocket, this.defaultDataMimeType, getRSocketStrategies()); + return new MessagingRSocket(this::handleMessage, + RSocketRequester.wrap(rsocket, this.defaultDataMimeType, getRSocketStrategies()), + this.defaultDataMimeType, + getRSocketStrategies().dataBufferFactory()); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessagingRSocket.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessagingRSocket.java index 604ce47178..e36875966c 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessagingRSocket.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MessagingRSocket.java @@ -61,18 +61,18 @@ class MessagingRSocket extends AbstractRSocket { @Nullable private MimeType dataMimeType; - private final RSocketStrategies strategies; + private final DataBufferFactory bufferFactory; - MessagingRSocket(Function, Mono> handler, RSocket sendingRSocket, - @Nullable MimeType defaultDataMimeType, RSocketStrategies strategies) { + MessagingRSocket(Function, Mono> handler, RSocketRequester requester, + @Nullable MimeType defaultDataMimeType, DataBufferFactory bufferFactory) { Assert.notNull(handler, "'handler' is required"); - Assert.notNull(sendingRSocket, "'sendingRSocket' is required"); + Assert.notNull(requester, "'requester' is required"); this.handler = handler; - this.requester = RSocketRequester.create(sendingRSocket, defaultDataMimeType, strategies); + this.requester = requester; this.dataMimeType = defaultDataMimeType; - this.strategies = strategies; + this.bufferFactory = bufferFactory; } @@ -175,7 +175,7 @@ class MessagingRSocket extends AbstractRSocket { } private DataBuffer retainDataAndReleasePayload(Payload payload) { - return PayloadUtils.retainDataAndReleasePayload(payload, this.strategies.dataBufferFactory()); + return PayloadUtils.retainDataAndReleasePayload(payload, this.bufferFactory); } private MessageHeaders createHeaders(String destination, @Nullable MonoProcessor replyMono) { @@ -189,8 +189,7 @@ class MessagingRSocket extends AbstractRSocket { if (replyMono != null) { headers.setHeader(RSocketPayloadReturnValueHandler.RESPONSE_HEADER, replyMono); } - DataBufferFactory bufferFactory = this.strategies.dataBufferFactory(); - headers.setHeader(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, bufferFactory); + headers.setHeader(HandlerMethodReturnValueHandler.DATA_BUFFER_FACTORY_HEADER, this.bufferFactory); return headers.getMessageHeaders(); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java index 5b18f14cb5..bc454b92b2 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/RSocketRequester.java @@ -71,6 +71,20 @@ public interface RSocketRequester { return new DefaultRSocketRequesterBuilder(); } + /** + * Wrap an existing {@link RSocket}. Typically used in a client or server + * responder to wrap the remote {@code RSocket}. + * @param rsocket the RSocket to wrap + * @param dataMimeType the data MimeType, obtained from the + * {@link io.rsocket.ConnectionSetupPayload} (server) or the + * {@link io.rsocket.RSocketFactory.ClientRSocketFactory} (client) + * @param strategies the strategies to use + * @return the created RSocketRequester + */ + static RSocketRequester wrap(RSocket rsocket, @Nullable MimeType dataMimeType, RSocketStrategies strategies) { + return new DefaultRSocketRequester(rsocket, dataMimeType, strategies); + } + /** * Create a new {@code RSocketRequester} from the given {@link RSocket} and * strategies for encoding and decoding request and response payloads. @@ -78,7 +92,9 @@ public interface RSocketRequester { * @param dataMimeType the MimeType for data (from the SETUP frame) * @param strategies encoders, decoders, and others * @return the created RSocketRequester wrapper + * @deprecated use {@link #wrap(RSocket, MimeType, RSocketStrategies)} instead */ + @Deprecated static RSocketRequester create(RSocket rsocket, @Nullable MimeType dataMimeType, RSocketStrategies strategies) { return new DefaultRSocketRequester(rsocket, dataMimeType, strategies); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java index 75316c15af..74db02b343 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterTests.java @@ -70,7 +70,7 @@ public class DefaultRSocketRequesterTests { .encoder(CharSequenceEncoder.allMimeTypes()) .build(); this.rsocket = new TestRSocket(); - this.requester = RSocketRequester.create(rsocket, MimeTypeUtils.TEXT_PLAIN, strategies); + this.requester = RSocketRequester.wrap(this.rsocket, MimeTypeUtils.TEXT_PLAIN, strategies); }