diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java index 1d19d9a784..6e320bd5d0 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultMetadataExtractor.java @@ -17,6 +17,7 @@ package org.springframework.messaging.rsocket; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -34,7 +35,6 @@ import org.springframework.core.codec.Decoder; import org.springframework.core.io.buffer.NettyDataBuffer; import org.springframework.core.io.buffer.NettyDataBufferFactory; import org.springframework.lang.Nullable; -import org.springframework.util.Assert; import org.springframework.util.MimeType; /** @@ -51,46 +51,28 @@ import org.springframework.util.MimeType; */ public class DefaultMetadataExtractor implements MetadataExtractor { - private final List> decoders = new ArrayList<>(); + private final List> decoders; - private final Map> processors = new HashMap<>(); + private final Map> registrations = new HashMap<>(); /** - * Configure the decoders to use for de-serializing metadata entries. - *

By default this is not set. + * Constructor with decoders for de-serializing metadata entries. */ - public void setDecoders(List> decoders) { - this.decoders.clear(); - if (!decoders.isEmpty()) { - this.decoders.addAll(decoders); - updateProcessors(); - } + public DefaultMetadataExtractor(Decoder... decoders) { + this(Arrays.asList(decoders)); } - @SuppressWarnings("unchecked") - private void updateProcessors() { - for (MetadataProcessor info : this.processors.values()) { - Decoder decoder = decoderFor(info.mimeType(), info.targetType()); - Assert.isTrue(decoder != null, "No decoder for " + info); - info = ((MetadataProcessor) info).setDecoder(decoder); - this.processors.put(info.mimeType().toString(), info); - } + /** + * Constructor with list of decoders for de-serializing metadata entries. + */ + public DefaultMetadataExtractor(List> decoders) { + this.decoders = Collections.unmodifiableList(new ArrayList<>(decoders)); } - @Nullable - @SuppressWarnings("unchecked") - private Decoder decoderFor(MimeType mimeType, ResolvableType type) { - for (Decoder decoder : this.decoders) { - if (decoder.canDecode(type, mimeType)) { - return (Decoder) decoder; - } - } - return null; - } /** - * Return the {@link #setDecoders(List) configured} decoders. + * Return a read-only list with the configured decoders. */ public List> getDecoders() { return this.decoders; @@ -106,9 +88,7 @@ public class DefaultMetadataExtractor implements MetadataExtractor { * @param name assign a name for the decoded value; if not provided, then * the mime type is used as the key */ - public void metadataToExtract( - MimeType mimeType, Class targetType, @Nullable String name) { - + public void metadataToExtract(MimeType mimeType, Class targetType, @Nullable String name) { String key = name != null ? name : mimeType.toString(); metadataToExtract(mimeType, targetType, (value, map) -> map.put(key, value)); } @@ -117,6 +97,8 @@ public class DefaultMetadataExtractor implements MetadataExtractor { * Variant of {@link #metadataToExtract(MimeType, Class, String)} that accepts * {@link ParameterizedTypeReference} instead of {@link Class} for * specifying a target type with generic parameters. + * @param mimeType the mime type of metadata entries to extract + * @param targetType the target value type to decode to */ public void metadataToExtract( MimeType mimeType, ParameterizedTypeReference targetType, @Nullable String name) { @@ -137,7 +119,7 @@ public class DefaultMetadataExtractor implements MetadataExtractor { public void metadataToExtract( MimeType mimeType, Class targetType, BiConsumer> mapper) { - metadataToExtract(mimeType, mapper, ResolvableType.forClass(targetType)); + registerMetadata(mimeType, ResolvableType.forClass(targetType), mapper); } /** @@ -145,24 +127,28 @@ public class DefaultMetadataExtractor implements MetadataExtractor { * accepts {@link ParameterizedTypeReference} instead of {@link Class} for * specifying a target type with generic parameters. * @param mimeType the mime type of metadata entries to extract - * @param targetType the target value type to decode to + * @param type the target value type to decode to * @param mapper custom logic to add the decoded value to the output map * @param the target value type */ public void metadataToExtract( - MimeType mimeType, ParameterizedTypeReference targetType, - BiConsumer> mapper) { + MimeType mimeType, ParameterizedTypeReference type, BiConsumer> mapper) { - metadataToExtract(mimeType, mapper, ResolvableType.forType(targetType)); + registerMetadata(mimeType, ResolvableType.forType(type), mapper); } - private void metadataToExtract( - MimeType mimeType, BiConsumer> mapper, ResolvableType elementType) { + @SuppressWarnings("unchecked") + private void registerMetadata( + MimeType mimeType, ResolvableType targetType, BiConsumer> mapper) { - Decoder decoder = decoderFor(mimeType, elementType); - Assert.isTrue(this.decoders.isEmpty() || decoder != null, () -> "No decoder for " + mimeType); - MetadataProcessor info = new MetadataProcessor<>(mimeType, elementType, mapper, decoder); - this.processors.put(mimeType.toString(), info); + for (Decoder decoder : this.decoders) { + if (decoder.canDecode(targetType, mimeType)) { + this.registrations.put(mimeType.toString(), + new EntryExtractor<>((Decoder) decoder, mimeType, targetType, mapper)); + return; + } + } + throw new IllegalArgumentException("No decoder for " + mimeType + " and " + targetType); } @@ -171,20 +157,19 @@ public class DefaultMetadataExtractor implements MetadataExtractor { Map result = new HashMap<>(); if (metadataMimeType.equals(COMPOSITE_METADATA)) { for (CompositeMetadata.Entry entry : new CompositeMetadata(payload.metadata(), false)) { - processEntry(entry.getContent(), entry.getMimeType(), result); + extractEntry(entry.getContent(), entry.getMimeType(), result); } } else { - processEntry(payload.metadata(), metadataMimeType.toString(), result); + extractEntry(payload.metadata(), metadataMimeType.toString(), result); } return result; } - @SuppressWarnings("unchecked") - private void processEntry(ByteBuf content, @Nullable String mimeType, Map result) { - MetadataProcessor info = (MetadataProcessor) this.processors.get(mimeType); - if (info != null) { - info.process(content, result); + private void extractEntry(ByteBuf content, @Nullable String mimeType, Map result) { + EntryExtractor extractor = this.registrations.get(mimeType); + if (extractor != null) { + extractor.extract(content, result); return; } if (MetadataExtractor.ROUTING.toString().equals(mimeType)) { @@ -194,56 +179,32 @@ public class DefaultMetadataExtractor implements MetadataExtractor { } - private static class MetadataProcessor { + private static class EntryExtractor { private final static NettyDataBufferFactory bufferFactory = new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT); + private final Decoder decoder; + private final MimeType mimeType; private final ResolvableType targetType; private final BiConsumer> accumulator; - @Nullable - private final Decoder decoder; - - MetadataProcessor(MimeType mimeType, ResolvableType targetType, - BiConsumer> accumulator, @Nullable Decoder decoder) { + EntryExtractor(Decoder decoder, MimeType mimeType, ResolvableType targetType, + BiConsumer> accumulator) { + this.decoder = decoder; this.mimeType = mimeType; this.targetType = targetType; this.accumulator = accumulator; - this.decoder = decoder; - } - - MetadataProcessor(MetadataProcessor other, Decoder decoder) { - this.mimeType = other.mimeType; - this.targetType = other.targetType; - this.accumulator = other.accumulator; - this.decoder = decoder; } - public MimeType mimeType() { - return this.mimeType; - } - - public ResolvableType targetType() { - return this.targetType; - } - - public MetadataProcessor setDecoder(Decoder decoder) { - return this.decoder != decoder ? new MetadataProcessor<>(this, decoder) : this; - } - - - public void process(ByteBuf content, Map result) { - if (this.decoder == null) { - throw new IllegalStateException("No decoder for " + this); - } + public void extract(ByteBuf content, Map result) { NettyDataBuffer dataBuffer = bufferFactory.wrap(content.retain()); T value = this.decoder.decode(dataBuffer, this.targetType, this.mimeType, Collections.emptyMap()); this.accumulator.accept(value, result); @@ -252,7 +213,7 @@ public class DefaultMetadataExtractor implements MetadataExtractor { @Override public String toString() { - return "MetadataProcessor mimeType=" + this.mimeType + ", targetType=" + this.targetType; + return "mimeType=" + this.mimeType + ", targetType=" + this.targetType; } } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketStrategies.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketStrategies.java index 28df18aa7d..21435e0544 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketStrategies.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/DefaultRSocketStrategies.java @@ -205,12 +205,20 @@ final class DefaultRSocketStrategies implements RSocketStrategies { @Override public RSocketStrategies build() { + + RouteMatcher matcher = this.routeMatcher != null ? this.routeMatcher : initRouteMatcher(); + + MetadataExtractor extractor = this.metadataExtractor != null ? + this.metadataExtractor : new DefaultMetadataExtractor(this.decoders); + + DataBufferFactory factory = this.bufferFactory != null ? + this.bufferFactory : new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT); + + ReactiveAdapterRegistry registry = this.adapterRegistry != null ? + this.adapterRegistry : ReactiveAdapterRegistry.getSharedInstance(); + return new DefaultRSocketStrategies( - this.encoders, this.decoders, - this.routeMatcher != null ? this.routeMatcher : initRouteMatcher(), - this.metadataExtractor != null ? this.metadataExtractor : initMetadataExtractor(), - this.bufferFactory != null ? this.bufferFactory : initBufferFactory(), - this.adapterRegistry != null ? this.adapterRegistry : ReactiveAdapterRegistry.getSharedInstance()); + this.encoders, this.decoders, matcher, extractor, factory, registry); } private RouteMatcher initRouteMatcher() { @@ -218,16 +226,6 @@ final class DefaultRSocketStrategies implements RSocketStrategies { pathMatcher.setPathSeparator("."); return new SimpleRouteMatcher(pathMatcher); } - - private MetadataExtractor initMetadataExtractor() { - DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); - extractor.setDecoders(this.decoders); - return extractor; - } - - private DataBufferFactory initBufferFactory() { - return new NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT); - } } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java index bad58c8ac8..438087fa97 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/DefaultMetadataExtractorTests.java @@ -37,7 +37,6 @@ import org.springframework.util.MimeType; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.springframework.messaging.rsocket.MetadataExtractor.COMPOSITE_METADATA; import static org.springframework.messaging.rsocket.MetadataExtractor.ROUTE_KEY; import static org.springframework.messaging.rsocket.MetadataExtractor.ROUTING; @@ -71,8 +70,7 @@ public class DefaultMetadataExtractorTests { this.captor = ArgumentCaptor.forClass(Payload.class); BDDMockito.when(this.rsocket.fireAndForget(captor.capture())).thenReturn(Mono.empty()); - this.extractor = new DefaultMetadataExtractor(); - this.extractor.setDecoders(Collections.singletonList(StringDecoder.allMimeTypes())); + this.extractor = new DefaultMetadataExtractor(StringDecoder.allMimeTypes()); } @After @@ -165,56 +163,14 @@ public class DefaultMetadataExtractorTests { } @Test - public void addMetadataToExtractBeforeDecoders() { - DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); - extractor.metadataToExtract(TEXT_PLAIN, String.class, "key"); - extractor.setDecoders(Collections.singletonList(StringDecoder.allMimeTypes())); - - requester(TEXT_PLAIN).metadata("meta entry", null).data("data").send().block(); - Payload payload = this.captor.getValue(); - Map result = extractor.extract(payload, TEXT_PLAIN); - payload.release(); - - assertThat(result).hasSize(1).containsEntry("key", "meta entry"); - } - - @Test - public void noDecoderExceptionWhenSettingDecoders() { - DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); - extractor.metadataToExtract(TEXT_PLAIN, String.class, "key"); - - assertThatIllegalArgumentException() - .isThrownBy(() -> extractor.setDecoders(Collections.singletonList(new ByteArrayDecoder()))) - .withMessage("No decoder for MetadataProcessor mimeType=text/plain, targetType=java.lang.String"); - } - - @Test - public void noDecoderExceptionWhenRegisteringMetadataToExtract() { - DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); - extractor.setDecoders(Collections.singletonList(new ByteArrayDecoder())); + public void noDecoder() { + DefaultMetadataExtractor extractor = + new DefaultMetadataExtractor(Collections.singletonList(new ByteArrayDecoder()) + ); assertThatIllegalArgumentException() .isThrownBy(() -> extractor.metadataToExtract(TEXT_PLAIN, String.class, "key")) - .withMessage("No decoder for text/plain"); - } - - @Test - public void decodersNotSet() { - DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); - extractor.metadataToExtract(TEXT_PLAIN, String.class, "key"); - - assertThatIllegalStateException() - .isThrownBy(() -> { - requester(TEXT_PLAIN).metadata("meta entry", null).data("data").send().block(); - Payload payload = this.captor.getValue(); - try { - extractor.extract(payload, TEXT_PLAIN); - } - finally { - payload.release(); - } - }) - .withMessage("No decoder for MetadataProcessor mimeType=text/plain, targetType=java.lang.String"); + .withMessage("No decoder for text/plain and java.lang.String"); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketServerToClientIntegrationTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketServerToClientIntegrationTests.java index ef046794e9..dc443b6aa7 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketServerToClientIntegrationTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/RSocketServerToClientIntegrationTests.java @@ -17,7 +17,6 @@ package org.springframework.messaging.rsocket; import java.time.Duration; -import java.util.Collections; import io.rsocket.RSocketFactory; import io.rsocket.SocketAcceptor; @@ -267,8 +266,7 @@ public class RSocketServerToClientIntegrationTests { @Bean public RSocketStrategies rsocketStrategies() { - DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); - extractor.setDecoders(Collections.singletonList(StringDecoder.allMimeTypes())); + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(StringDecoder.allMimeTypes()); extractor.metadataToExtract(MimeTypeUtils.TEXT_PLAIN, String.class, MetadataExtractor.ROUTE_KEY); return RSocketStrategies.builder().metadataExtractor(extractor).build(); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandlerTests.java b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandlerTests.java index 21b476c331..91f0917842 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandlerTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/rsocket/annotation/support/RSocketMessageHandlerTests.java @@ -129,8 +129,7 @@ public class RSocketMessageHandlerTests { @Test public void metadataExtractorWithExplicitlySetDecoders() { - DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(); - extractor.setDecoders(Collections.singletonList(StringDecoder.allMimeTypes())); + DefaultMetadataExtractor extractor = new DefaultMetadataExtractor(StringDecoder.allMimeTypes()); RSocketMessageHandler handler = new RSocketMessageHandler(); handler.setDecoders(Arrays.asList(new ByteArrayDecoder(), new ByteBufferDecoder()));