diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java index 49462c3407..8485afa405 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java @@ -28,6 +28,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.Payload; import io.rsocket.metadata.CompositeMetadata; +import io.rsocket.metadata.WellKnownMimeType; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ResolvableType; @@ -153,7 +154,7 @@ public class DefaultMetadataExtractor implements MetadataExtractor { @Override public Map extract(Payload payload, MimeType metadataMimeType) { Map result = new HashMap<>(); - if (metadataMimeType.equals(COMPOSITE_METADATA)) { + if (metadataMimeType.toString().equals(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.toString())) { for (CompositeMetadata.Entry entry : new CompositeMetadata(payload.metadata(), false)) { extractEntry(entry.getContent(), entry.getMimeType(), result); } @@ -170,7 +171,7 @@ public class DefaultMetadataExtractor implements MetadataExtractor { extractor.extract(content, result); return; } - if (MetadataExtractor.ROUTING.toString().equals(mimeType)) { + if (mimeType != null && mimeType.equals(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString())) { // TODO: use rsocket-core API when available result.put(MetadataExtractor.ROUTE_KEY, content.toString(StandardCharsets.UTF_8)); } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java index 84029fccdd..4cc8b83afc 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilder.java @@ -27,6 +27,7 @@ import java.util.function.Consumer; import io.rsocket.Payload; import io.rsocket.RSocketFactory; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.metadata.WellKnownMimeType; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.client.WebsocketClientTransport; @@ -43,6 +44,7 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; /** * Default implementation of {@link RSocketRequester.Builder}. @@ -59,7 +61,8 @@ final class DefaultRSocketRequesterBuilder implements RSocketRequester.Builder { @Nullable private MimeType dataMimeType; - private MimeType metadataMimeType = MetadataExtractor.COMPOSITE_METADATA; + @Nullable + private MimeType metadataMimeType; @Nullable private Object setupData; @@ -159,11 +162,14 @@ final class DefaultRSocketRequesterBuilder implements RSocketRequester.Builder { factory.frameDecoder(PayloadDecoder.ZERO_COPY); } + MimeType metaMimeType = this.metadataMimeType != null ? this.metadataMimeType : + MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + MimeType dataMimeType = getDataMimeType(rsocketStrategies); factory.dataMimeType(dataMimeType.toString()); - factory.metadataMimeType(this.metadataMimeType.toString()); + factory.metadataMimeType(metaMimeType.toString()); - Payload setupPayload = getSetupPayload(dataMimeType, rsocketStrategies); + Payload setupPayload = getSetupPayload(dataMimeType, metaMimeType, rsocketStrategies); if (setupPayload != null) { factory.setupPayload(setupPayload); } @@ -171,14 +177,14 @@ final class DefaultRSocketRequesterBuilder implements RSocketRequester.Builder { return factory.transport(transport) .start() .map(rsocket -> new DefaultRSocketRequester( - rsocket, dataMimeType, this.metadataMimeType, rsocketStrategies)); + rsocket, dataMimeType, metaMimeType, rsocketStrategies)); } @Nullable - private Payload getSetupPayload(MimeType dataMimeType, RSocketStrategies strategies) { + private Payload getSetupPayload(MimeType dataMimeType, MimeType metaMimeType, RSocketStrategies strategies) { DataBuffer metadata = null; if (this.setupRoute != null || !CollectionUtils.isEmpty(this.setupMetadata)) { - metadata = new MetadataEncoder(this.metadataMimeType, strategies) + metadata = new MetadataEncoder(metaMimeType, strategies) .metadataAndOrRoute(this.setupMetadata, this.setupRoute, this.setupRouteVars) .encode(); } @@ -246,7 +252,7 @@ final class DefaultRSocketRequesterBuilder implements RSocketRequester.Builder { private static MimeType getMimeType(Decoder decoder) { MimeType mimeType = decoder.getDecodableMimeTypes().get(0); - return new MimeType(mimeType, Collections.emptyMap()); + return mimeType.getParameters().isEmpty() ? mimeType : new MimeType(mimeType, Collections.emptyMap()); } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataEncoder.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataEncoder.java index 745e66d472..f5f2e289f4 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataEncoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataEncoder.java @@ -25,6 +25,7 @@ import java.util.regex.Pattern; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; import io.rsocket.metadata.CompositeMetadataFlyweight; +import io.rsocket.metadata.WellKnownMimeType; import org.springframework.core.ResolvableType; import org.springframework.core.codec.Encoder; @@ -46,7 +47,7 @@ import org.springframework.util.ObjectUtils; final class MetadataEncoder { /** For route variable replacement. */ - private static final Pattern VARS_PATTERN = Pattern.compile("\\{([^/]+?)\\}"); + private static final Pattern VARS_PATTERN = Pattern.compile("\\{([^/]+?)}"); private final MimeType metadataMimeType; @@ -68,7 +69,8 @@ final class MetadataEncoder { Assert.notNull(strategies, "RSocketStrategies is required"); this.metadataMimeType = metadataMimeType; this.strategies = strategies; - this.isComposite = metadataMimeType.equals(MetadataExtractor.COMPOSITE_METADATA); + this.isComposite = this.metadataMimeType.toString().equals( + WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); this.allocator = bufferFactory() instanceof NettyDataBufferFactory ? ((NettyDataBufferFactory) bufferFactory()).getByteBufAllocator() : ByteBufAllocator.DEFAULT; } @@ -157,11 +159,15 @@ final class MetadataEncoder { * @see PayloadUtils#createPayload(DataBuffer, DataBuffer) */ public DataBuffer encode() { - Map mergedMetadata = mergeRouteAndMetadata(); if (this.isComposite) { CompositeByteBuf composite = this.allocator.compositeBuffer(); + if (this.route != null) { + CompositeMetadataFlyweight.encodeAndAddMetadata(composite, this.allocator, + WellKnownMimeType.MESSAGE_RSOCKET_ROUTING, + PayloadUtils.asByteBuf(bufferFactory().wrap(this.route.getBytes(StandardCharsets.UTF_8)))); + } try { - mergedMetadata.forEach((value, mimeType) -> { + this.metadata.forEach((value, mimeType) -> { DataBuffer buffer = encodeEntry(value, mimeType); CompositeMetadataFlyweight.encodeAndAddMetadata( composite, this.allocator, mimeType.toString(), PayloadUtils.asByteBuf(buffer)); @@ -180,38 +186,24 @@ final class MetadataEncoder { throw ex; } } + else if (this.route != null) { + Assert.isTrue(this.metadata.isEmpty(), "Composite metadata required for route and other entries"); + return this.metadataMimeType.toString().equals(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()) ? + bufferFactory().wrap(this.route.getBytes(StandardCharsets.UTF_8)) : + encodeEntry(this.route, this.metadataMimeType); + } else { - Assert.isTrue(mergedMetadata.size() == 1, "Composite metadata required for multiple entries"); - Map.Entry entry = mergedMetadata.entrySet().iterator().next(); + Assert.isTrue(this.metadata.size() == 1, "Composite metadata required for multiple entries"); + Map.Entry entry = this.metadata.entrySet().iterator().next(); if (!this.metadataMimeType.equals(entry.getValue())) { throw new IllegalArgumentException( "Connection configured for metadata mime type " + - "'" + this.metadataMimeType + "', but actual is `" + mergedMetadata + "`"); + "'" + this.metadataMimeType + "', but actual is `" + this.metadata + "`"); } return encodeEntry(entry.getKey(), entry.getValue()); } } - private Map mergeRouteAndMetadata() { - if (this.route == null) { - return this.metadata; - } - - MimeType routeMimeType = this.metadataMimeType.equals(MetadataExtractor.COMPOSITE_METADATA) ? - MetadataExtractor.ROUTING : this.metadataMimeType; - - Object routeValue = this.route; - if (routeMimeType.equals(MetadataExtractor.ROUTING)) { - // TODO: use rsocket-core API when available - routeValue = bufferFactory().wrap(this.route.getBytes(StandardCharsets.UTF_8)); - } - - Map result = new LinkedHashMap<>(); - result.put(routeValue, routeMimeType); - result.putAll(this.metadata); - return result; - } - @SuppressWarnings("unchecked") private DataBuffer encodeEntry(Object metadata, MimeType mimeType) { if (metadata instanceof DataBuffer) { diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java index ca617c088e..41a9204e38 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/MetadataExtractor.java @@ -38,16 +38,6 @@ public interface MetadataExtractor { */ String ROUTE_KEY = "route"; - /** - * Constant MimeType {@code "message/x.rsocket.composite-metadata.v0"}. - */ - MimeType COMPOSITE_METADATA = new MimeType("message", "x.rsocket.composite-metadata.v0"); - - /** - * Constant for MimeType {@code "message/x.rsocket.routing.v0"}. - */ - MimeType ROUTING = new MimeType("message", "x.rsocket.routing.v0"); - /** * Extract a map of values from the given {@link Payload} metadata. diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java index cc83c7f510..7ee94e85c6 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandler.java @@ -24,6 +24,7 @@ import io.rsocket.ConnectionSetupPayload; import io.rsocket.RSocket; import io.rsocket.SocketAcceptor; import io.rsocket.frame.FrameType; +import io.rsocket.metadata.WellKnownMimeType; import reactor.core.publisher.Mono; import org.springframework.beans.BeanUtils; @@ -73,7 +74,8 @@ public class RSocketMessageHandler extends MessageMappingMessageHandler { @Nullable private MimeType defaultDataMimeType; - private MimeType defaultMetadataMimeType = MetadataExtractor.COMPOSITE_METADATA; + private MimeType defaultMetadataMimeType = MimeTypeUtils.parseMimeType( + WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); public RSocketMessageHandler() { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java index a95eef7589..cd1052a4ae 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java @@ -21,6 +21,7 @@ import java.util.Map; import io.netty.buffer.PooledByteBufAllocator; import io.rsocket.Payload; +import io.rsocket.metadata.WellKnownMimeType; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -30,12 +31,12 @@ import org.springframework.core.codec.StringDecoder; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.util.Assert; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.springframework.messaging.rsocket.MetadataExtractor.COMPOSITE_METADATA; import static org.springframework.messaging.rsocket.MetadataExtractor.ROUTE_KEY; -import static org.springframework.messaging.rsocket.MetadataExtractor.ROUTING; import static org.springframework.util.MimeTypeUtils.TEXT_HTML; import static org.springframework.util.MimeTypeUtils.TEXT_PLAIN; import static org.springframework.util.MimeTypeUtils.TEXT_XML; @@ -47,6 +48,10 @@ import static org.springframework.util.MimeTypeUtils.TEXT_XML; */ public class DefaultMetadataExtractorTests { + private static MimeType COMPOSITE_METADATA = + MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + + private RSocketStrategies strategies; private DefaultMetadataExtractor extractor; @@ -108,10 +113,11 @@ public class DefaultMetadataExtractorTests { @Test public void route() { - MetadataEncoder metadataEncoder = new MetadataEncoder(ROUTING, this.strategies).route("toA"); + MimeType metaMimeType = MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + MetadataEncoder metadataEncoder = new MetadataEncoder(metaMimeType, this.strategies).route("toA"); DataBuffer metadata = metadataEncoder.encode(); Payload payload = createPayload(metadata); - Map result = this.extractor.extract(payload, ROUTING); + Map result = this.extractor.extract(payload, metaMimeType); payload.release(); assertThat(result).hasSize(1).containsEntry(ROUTE_KEY, "toA"); diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilderTests.java index cc3f0f98eb..80761f56bc 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilderTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultRSocketRequesterBuilderTests.java @@ -28,6 +28,7 @@ import io.rsocket.ConnectionSetupPayload; import io.rsocket.DuplexConnection; import io.rsocket.RSocketFactory; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.metadata.WellKnownMimeType; import io.rsocket.transport.ClientTransport; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -150,7 +151,7 @@ public class DefaultRSocketRequesterBuilderTests { @Test public void mimeTypesCannotBeChangedAtRSocketFactoryLevel() { MimeType dataMimeType = MimeTypeUtils.APPLICATION_JSON; - MimeType metaMimeType = MetadataExtractor.ROUTING; + MimeType metaMimeType = MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); RSocketRequester requester = RSocketRequester.builder() .metadataMimeType(metaMimeType) diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/MetadataEncoderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/MetadataEncoderTests.java index ab7c37d6ea..a664e685fd 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/MetadataEncoderTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/MetadataEncoderTests.java @@ -23,6 +23,7 @@ import java.util.Map; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.metadata.CompositeMetadata; +import io.rsocket.metadata.WellKnownMimeType; import org.junit.jupiter.api.Test; import org.springframework.core.io.buffer.DataBuffer; @@ -43,12 +44,16 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; */ public class MetadataEncoderTests { + private static MimeType COMPOSITE_METADATA = + MimeTypeUtils.parseMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()); + + private final RSocketStrategies strategies = RSocketStrategies.create(); @Test public void compositeMetadataWithRoute() { - DataBuffer buffer = new MetadataEncoder(MetadataExtractor.COMPOSITE_METADATA, this.strategies) + DataBuffer buffer = new MetadataEncoder(COMPOSITE_METADATA, this.strategies) .route("toA") .encode(); @@ -57,7 +62,7 @@ public class MetadataEncoderTests { assertThat(iterator.hasNext()).isTrue(); CompositeMetadata.Entry entry = iterator.next(); - assertThat(entry.getMimeType()).isEqualTo(MetadataExtractor.ROUTING.toString()); + assertThat(entry.getMimeType()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); assertThat(entry.getContent().toString(StandardCharsets.UTF_8)).isEqualTo("toA"); assertThat(iterator.hasNext()).isFalse(); @@ -66,7 +71,7 @@ public class MetadataEncoderTests { @Test public void compositeMetadataWithRouteAndText() { - DataBuffer buffer = new MetadataEncoder(MetadataExtractor.COMPOSITE_METADATA, this.strategies) + DataBuffer buffer = new MetadataEncoder(COMPOSITE_METADATA, this.strategies) .route("toA") .metadata("My metadata", MimeTypeUtils.TEXT_PLAIN) .encode(); @@ -76,7 +81,7 @@ public class MetadataEncoderTests { assertThat(iterator.hasNext()).isTrue(); CompositeMetadata.Entry entry = iterator.next(); - assertThat(entry.getMimeType()).isEqualTo(MetadataExtractor.ROUTING.toString()); + assertThat(entry.getMimeType()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); assertThat(entry.getContent().toString(StandardCharsets.UTF_8)).isEqualTo("toA"); assertThat(iterator.hasNext()).isTrue(); @@ -89,8 +94,11 @@ public class MetadataEncoderTests { @Test public void routeWithRoutingMimeType() { + MimeType metaMimeType = MimeTypeUtils.parseMimeType( + WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + DataBuffer buffer = - new MetadataEncoder(MetadataExtractor.ROUTING, this.strategies) + new MetadataEncoder(metaMimeType, this.strategies) .route("toA") .encode(); @@ -154,7 +162,7 @@ public class MetadataEncoderTests { @Test public void mimeTypeRequiredForCompositeEntries() { - MetadataEncoder encoder = new MetadataEncoder(MetadataExtractor.COMPOSITE_METADATA, this.strategies); + MetadataEncoder encoder = new MetadataEncoder(COMPOSITE_METADATA, this.strategies); assertThatThrownBy(() -> encoder.metadata("toA", null)) .hasMessage("MimeType is required for composite metadata entries."); @@ -174,7 +182,7 @@ public class MetadataEncoderTests { DefaultDataBufferFactory bufferFactory = new DefaultDataBufferFactory(); RSocketStrategies strategies = RSocketStrategies.builder().dataBufferFactory(bufferFactory).build(); - DataBuffer buffer = new MetadataEncoder(MetadataExtractor.COMPOSITE_METADATA, strategies) + DataBuffer buffer = new MetadataEncoder(COMPOSITE_METADATA, strategies) .route("toA") .encode(); @@ -187,7 +195,7 @@ public class MetadataEncoderTests { assertThat(iterator.hasNext()).isTrue(); CompositeMetadata.Entry entry = iterator.next(); - assertThat(entry.getMimeType()).isEqualTo(MetadataExtractor.ROUTING.toString()); + assertThat(entry.getMimeType()).isEqualTo(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); assertThat(entry.getContent().toString(StandardCharsets.UTF_8)).isEqualTo("toA"); assertThat(iterator.hasNext()).isFalse();