diff --git a/spring-web/src/main/java/org/springframework/http/codec/CodecConfigurer.java b/spring-web/src/main/java/org/springframework/http/codec/CodecConfigurer.java index 4b9487aefe..f4d8a3941c 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/CodecConfigurer.java +++ b/spring-web/src/main/java/org/springframework/http/codec/CodecConfigurer.java @@ -127,6 +127,22 @@ public interface CodecConfigurer { */ void protobufEncoder(Encoder encoder); + /** + * Override the default JAXB2 {@code Decoder}. + * @param decoder the decoder instance to use + * @since 5.1.3 + * @see org.springframework.http.codec.xml.Jaxb2XmlDecoder + */ + void jaxb2Decoder(Decoder decoder); + + /** + * Override the default JABX2 {@code Encoder}. + * @param encoder the encoder instance to use + * @since 5.1.3 + * @see org.springframework.http.codec.xml.Jaxb2XmlEncoder + */ + void jaxb2Encoder(Encoder encoder); + /** * Whether to log form data at DEBUG level, and headers at TRACE level. * Both may contain sensitive information. diff --git a/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java b/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java index 2634f862ca..19529c004e 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java +++ b/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java @@ -89,6 +89,12 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs { @Nullable private Encoder protobufEncoder; + @Nullable + private Decoder jaxb2Decoder; + + @Nullable + private Encoder jaxb2Encoder; + private boolean enableLoggingRequestDetails = false; private boolean registerDefaults = true; @@ -114,6 +120,16 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs { this.protobufEncoder = encoder; } + @Override + public void jaxb2Decoder(Decoder decoder) { + this.jaxb2Decoder = decoder; + } + + @Override + public void jaxb2Encoder(Encoder encoder) { + this.jaxb2Encoder = encoder; + } + @Override public void enableLoggingRequestDetails(boolean enable) { this.enableLoggingRequestDetails = enable; @@ -145,7 +161,8 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs { readers.add(new DecoderHttpMessageReader<>(new ResourceDecoder())); readers.add(new DecoderHttpMessageReader<>(StringDecoder.textPlainOnly())); if (protobufPresent) { - readers.add(new DecoderHttpMessageReader<>(getProtobufDecoder())); + Decoder decoder = this.protobufDecoder != null ? this.protobufDecoder : new ProtobufDecoder(); + readers.add(new DecoderHttpMessageReader<>(decoder)); } FormHttpMessageReader formReader = new FormHttpMessageReader(); @@ -178,7 +195,8 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs { readers.add(new DecoderHttpMessageReader<>(new Jackson2SmileDecoder())); } if (jaxb2Present) { - readers.add(new DecoderHttpMessageReader<>(new Jaxb2XmlDecoder())); + Decoder decoder = this.jaxb2Decoder != null ? this.jaxb2Decoder : new Jaxb2XmlDecoder(); + readers.add(new DecoderHttpMessageReader<>(decoder)); } extendObjectReaders(readers); return readers; @@ -224,7 +242,8 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs { extendTypedWriters(writers); } if (protobufPresent) { - writers.add(new ProtobufHttpMessageWriter((Encoder) getProtobufEncoder())); + Encoder encoder = this.protobufEncoder != null ? this.protobufEncoder : new ProtobufEncoder(); + writers.add(new ProtobufHttpMessageWriter((Encoder) encoder)); } return writers; } @@ -253,7 +272,8 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs { writers.add(new EncoderHttpMessageWriter<>(new Jackson2SmileEncoder())); } if (jaxb2Present) { - writers.add(new EncoderHttpMessageWriter<>(new Jaxb2XmlEncoder())); + Encoder encoder = this.jaxb2Encoder != null ? this.jaxb2Encoder : new Jaxb2XmlEncoder(); + writers.add(new EncoderHttpMessageWriter<>(encoder)); } // No client or server specific multipart writers currently.. if (!forMultipart) { @@ -291,12 +311,4 @@ class BaseDefaultCodecs implements CodecConfigurer.DefaultCodecs { return (this.jackson2JsonEncoder != null ? this.jackson2JsonEncoder : new Jackson2JsonEncoder()); } - protected Decoder getProtobufDecoder() { - return (this.protobufDecoder != null ? this.protobufDecoder : new ProtobufDecoder()); - } - - protected Encoder getProtobufEncoder() { - return (this.protobufEncoder != null ? this.protobufEncoder : new ProtobufEncoder()); - } - } diff --git a/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlDecoder.java b/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlDecoder.java index 0ed9580904..1b5b15e747 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlDecoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlDecoder.java @@ -75,12 +75,32 @@ public class Jaxb2XmlDecoder extends AbstractDecoder { private final JaxbContextContainer jaxbContexts = new JaxbContextContainer(); + private Function unmarshallerProcessor = Function.identity(); + public Jaxb2XmlDecoder() { super(MimeTypeUtils.APPLICATION_XML, MimeTypeUtils.TEXT_XML); } + /** + * Configure a processor function to customize Unmarshaller instances. + * @param processor the function to use + * @since 5.1.3 + */ + public void setUnmarshallerProcessor(Function processor) { + this.unmarshallerProcessor = this.unmarshallerProcessor.andThen(processor); + } + + /** + * Return the configured processor for customizing Unmarshaller instances. + * @since 5.1.3 + */ + public Function getUnmarshallerProcessor() { + return this.unmarshallerProcessor; + } + + @Override public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) { if (super.canDecode(elementType, mimeType)) { @@ -123,7 +143,7 @@ public class Jaxb2XmlDecoder extends AbstractDecoder { private Object unmarshal(List events, Class outputClass) { try { - Unmarshaller unmarshaller = this.jaxbContexts.createUnmarshaller(outputClass); + Unmarshaller unmarshaller = initUnmarshaller(outputClass); XMLEventReader eventReader = StaxUtils.createXMLEventReader(events); if (outputClass.isAnnotationPresent(XmlRootElement.class)) { return unmarshaller.unmarshal(eventReader); @@ -141,6 +161,11 @@ public class Jaxb2XmlDecoder extends AbstractDecoder { } } + private Unmarshaller initUnmarshaller(Class outputClass) throws JAXBException { + Unmarshaller unmarshaller = this.jaxbContexts.createUnmarshaller(outputClass); + return this.unmarshallerProcessor.apply(unmarshaller); + } + /** * Returns the qualified name for the given class, according to the mapping rules * in the JAXB specification. diff --git a/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlEncoder.java index 905c60e83d..9562936796 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlEncoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/xml/Jaxb2XmlEncoder.java @@ -19,6 +19,7 @@ package org.springframework.http.codec.xml; import java.io.OutputStream; import java.nio.charset.StandardCharsets; import java.util.Map; +import java.util.function.Function; import javax.xml.bind.JAXBException; import javax.xml.bind.MarshalException; import javax.xml.bind.Marshaller; @@ -57,12 +58,32 @@ public class Jaxb2XmlEncoder extends AbstractSingleValueEncoder { private final JaxbContextContainer jaxbContexts = new JaxbContextContainer(); + private Function marshallerProcessor = Function.identity(); + public Jaxb2XmlEncoder() { super(MimeTypeUtils.APPLICATION_XML, MimeTypeUtils.TEXT_XML); } + /** + * Configure a processor function to customize Marshaller instances. + * @param processor the function to use + * @since 5.1.3 + */ + public void setMarshallerProcessor(Function processor) { + this.marshallerProcessor = this.marshallerProcessor.andThen(processor); + } + + /** + * Return the configured processor for customizing Marshaller instances. + * @since 5.1.3 + */ + public Function getMarshallerProcessor() { + return this.marshallerProcessor; + } + + @Override public boolean canEncode(ResolvableType elementType, @Nullable MimeType mimeType) { if (super.canEncode(elementType, mimeType)) { @@ -92,8 +113,7 @@ public class Jaxb2XmlEncoder extends AbstractSingleValueEncoder { Class clazz = ClassUtils.getUserClass(value); try { - Marshaller marshaller = this.jaxbContexts.createMarshaller(clazz); - marshaller.setProperty(Marshaller.JAXB_ENCODING, StandardCharsets.UTF_8.name()); + Marshaller marshaller = initMarshaller(clazz); marshaller.marshal(value, outputStream); release = false; return Flux.just(buffer); @@ -112,4 +132,11 @@ public class Jaxb2XmlEncoder extends AbstractSingleValueEncoder { } } + private Marshaller initMarshaller(Class clazz) throws JAXBException { + Marshaller marshaller = this.jaxbContexts.createMarshaller(clazz); + marshaller.setProperty(Marshaller.JAXB_ENCODING, StandardCharsets.UTF_8.name()); + marshaller = this.marshallerProcessor.apply(marshaller); + return marshaller; + } + } diff --git a/spring-web/src/test/java/org/springframework/http/codec/support/CodecConfigurerTests.java b/spring-web/src/test/java/org/springframework/http/codec/support/CodecConfigurerTests.java index 320ee4ea35..3f7057eea9 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/support/CodecConfigurerTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/support/CodecConfigurerTests.java @@ -244,58 +244,29 @@ public class CodecConfigurerTests { } @Test - public void jackson2DecoderOverride() { - Jackson2JsonDecoder decoder = new Jackson2JsonDecoder(); - this.configurer.defaultCodecs().jackson2JsonDecoder(decoder); - - assertSame(decoder, this.configurer.getReaders().stream() - .filter(writer -> writer instanceof DecoderHttpMessageReader) - .map(writer -> ((DecoderHttpMessageReader) writer).getDecoder()) - .filter(e -> Jackson2JsonDecoder.class.equals(e.getClass())) - .findFirst() - .filter(e -> e == decoder).orElse(null)); - } - - @Test - public void jackson2EncoderOverride() { - Jackson2JsonEncoder encoder = new Jackson2JsonEncoder(); - this.configurer.defaultCodecs().jackson2JsonEncoder(encoder); - - assertSame(encoder, this.configurer.getWriters().stream() - .filter(writer -> writer instanceof EncoderHttpMessageWriter) - .map(writer -> ((EncoderHttpMessageWriter) writer).getEncoder()) - .filter(e -> Jackson2JsonEncoder.class.equals(e.getClass())) - .findFirst() - .filter(e -> e == encoder).orElse(null)); - } - - @Test - public void protobufDecoderOverride() { - ProtobufDecoder decoder = new ProtobufDecoder(ExtensionRegistry.newInstance()); - this.configurer.defaultCodecs().protobufDecoder(decoder); - - assertSame(decoder, this.configurer.getReaders().stream() - .filter(writer -> writer instanceof DecoderHttpMessageReader) - .map(writer -> ((DecoderHttpMessageReader) writer).getDecoder()) - .filter(e -> ProtobufDecoder.class.equals(e.getClass())) - .findFirst() - .filter(e -> e == decoder).orElse(null)); + public void encoderDecoderOverrides() { + Jackson2JsonDecoder jacksonDecoder = new Jackson2JsonDecoder(); + Jackson2JsonEncoder jacksonEncoder = new Jackson2JsonEncoder(); + ProtobufDecoder protobufDecoder = new ProtobufDecoder(ExtensionRegistry.newInstance()); + ProtobufEncoder protobufEncoder = new ProtobufEncoder(); + Jaxb2XmlEncoder jaxb2Encoder = new Jaxb2XmlEncoder(); + Jaxb2XmlDecoder jaxb2Decoder = new Jaxb2XmlDecoder(); + + this.configurer.defaultCodecs().jackson2JsonDecoder(jacksonDecoder); + this.configurer.defaultCodecs().jackson2JsonEncoder(jacksonEncoder); + this.configurer.defaultCodecs().protobufDecoder(protobufDecoder); + this.configurer.defaultCodecs().protobufEncoder(protobufEncoder); + this.configurer.defaultCodecs().jaxb2Decoder(jaxb2Decoder); + this.configurer.defaultCodecs().jaxb2Encoder(jaxb2Encoder); + + assertDecoderInstance(jacksonDecoder); + assertDecoderInstance(protobufDecoder); + assertDecoderInstance(jaxb2Decoder); + assertEncoderInstance(jacksonEncoder); + assertEncoderInstance(protobufEncoder); + assertEncoderInstance(jaxb2Encoder); } - @Test - public void protobufEncoderOverride() { - ProtobufEncoder encoder = new ProtobufEncoder(); - this.configurer.defaultCodecs().protobufEncoder(encoder); - - assertSame(encoder, this.configurer.getWriters().stream() - .filter(writer -> writer instanceof EncoderHttpMessageWriter) - .map(writer -> ((EncoderHttpMessageWriter) writer).getEncoder()) - .filter(e -> ProtobufEncoder.class.equals(e.getClass())) - .findFirst() - .filter(e -> e == encoder).orElse(null)); - } - - private Decoder getNextDecoder(List> readers) { HttpMessageReader reader = readers.get(this.index.getAndIncrement()); assertEquals(DecoderHttpMessageReader.class, reader.getClass()); @@ -320,6 +291,23 @@ public class CodecConfigurerTests { assertEquals(!textOnly, encoder.canEncode(ResolvableType.forClass(String.class), MediaType.TEXT_EVENT_STREAM)); } + private void assertDecoderInstance(Decoder decoder) { + assertSame(decoder, this.configurer.getReaders().stream() + .filter(writer -> writer instanceof DecoderHttpMessageReader) + .map(writer -> ((DecoderHttpMessageReader) writer).getDecoder()) + .filter(e -> decoder.getClass().equals(e.getClass())) + .findFirst() + .filter(e -> e == decoder).orElse(null)); + } + + private void assertEncoderInstance(Encoder encoder) { + assertSame(encoder, this.configurer.getWriters().stream() + .filter(writer -> writer instanceof EncoderHttpMessageWriter) + .map(writer -> ((EncoderHttpMessageWriter) writer).getEncoder()) + .filter(e -> encoder.getClass().equals(e.getClass())) + .findFirst() + .filter(e -> e == encoder).orElse(null)); + } private static class TestCodecConfigurer extends BaseCodecConfigurer {