diff --git a/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java index 149d0d2eaf..07be82ad60 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/support/MessageHeaderAccessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -136,19 +136,15 @@ public class MessageHeaderAccessor { * A constructor to create new headers. */ public MessageHeaderAccessor() { - this.headers = new MutableMessageHeaders(); + this(null); } /** * A constructor accepting the headers of an existing message to copy. + * @param message a message to copy the headers from, or {@code null} if none */ public MessageHeaderAccessor(Message message) { - if (message != null) { - this.headers = new MutableMessageHeaders(message.getHeaders()); - } - else { - this.headers = new MutableMessageHeaders(); - } + this.headers = new MutableMessageHeaders(message != null ? message.getHeaders() : null); } @@ -193,7 +189,6 @@ public class MessageHeaderAccessor { * @since 4.1 */ public void setImmutable() { - this.headers.setIdAndTimestamp(); this.headers.setImmutable(); } @@ -583,7 +578,7 @@ public class MessageHeaderAccessor { if (messageHeaders instanceof MutableMessageHeaders) { MutableMessageHeaders mutableHeaders = (MutableMessageHeaders) messageHeaders; - MessageHeaderAccessor headerAccessor = mutableHeaders.getMessageHeaderAccessor(); + MessageHeaderAccessor headerAccessor = mutableHeaders.getAccessor(); if (requiredType.isAssignableFrom(headerAccessor.getClass())) { return (T) headerAccessor; } @@ -603,7 +598,7 @@ public class MessageHeaderAccessor { public static MessageHeaderAccessor getMutableAccessor(Message message) { if (message.getHeaders() instanceof MutableMessageHeaders) { MutableMessageHeaders mutableHeaders = (MutableMessageHeaders) message.getHeaders(); - MessageHeaderAccessor accessor = mutableHeaders.getMessageHeaderAccessor(); + MessageHeaderAccessor accessor = mutableHeaders.getAccessor(); if (accessor != null) { return (accessor.isMutable() ? accessor : accessor.createAccessor(message)); } @@ -615,38 +610,23 @@ public class MessageHeaderAccessor { @SuppressWarnings("serial") private class MutableMessageHeaders extends MessageHeaders { - private boolean immutable; - - public MutableMessageHeaders() { - this(null); - } + private boolean mutable = true; public MutableMessageHeaders(Map headers) { super(headers, MessageHeaders.ID_VALUE_NONE, -1L); } - public MessageHeaderAccessor getMessageHeaderAccessor() { - return MessageHeaderAccessor.this; - } - @Override public Map getRawHeaders() { - Assert.state(!this.immutable, "Already immutable"); + Assert.state(this.mutable, "Already immutable"); return super.getRawHeaders(); } public void setImmutable() { - this.immutable = true; - } - - public boolean isMutable() { - return !this.immutable; - } - - public void setIdAndTimestamp() { - if (!isMutable()) { + if (!this.mutable) { return; } + if (getId() == null) { IdGenerator idGenerator = (MessageHeaderAccessor.this.idGenerator != null ? MessageHeaderAccessor.this.idGenerator : MessageHeaders.getIdGenerator()); @@ -655,11 +635,27 @@ public class MessageHeaderAccessor { getRawHeaders().put(ID, id); } } + if (getTimestamp() == null) { if (MessageHeaderAccessor.this.enableTimestamp) { getRawHeaders().put(TIMESTAMP, System.currentTimeMillis()); } } + + this.mutable = false; + } + + public boolean isMutable() { + return this.mutable; + } + + public MessageHeaderAccessor getAccessor() { + return MessageHeaderAccessor.this; + } + + protected Object writeReplace() { + // Serialize as regular MessageHeaders (without MessageHeaderAccessor reference) + return new MessageHeaders(this); } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/support/MessageHeaderAccessorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/support/MessageHeaderAccessorTests.java index 82e2dce194..d9ec3f52b7 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/support/MessageHeaderAccessorTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/support/MessageHeaderAccessorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,10 +17,8 @@ package org.springframework.messaging.support; import java.nio.charset.StandardCharsets; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.Map; import java.util.UUID; @@ -30,8 +28,8 @@ import org.junit.rules.ExpectedException; import org.springframework.messaging.Message; import org.springframework.messaging.MessageHeaders; -import org.springframework.util.IdGenerator; import org.springframework.util.MimeTypeUtils; +import org.springframework.util.SerializationTestUtils; import static org.hamcrest.CoreMatchers.*; import static org.junit.Assert.*; @@ -166,7 +164,7 @@ public class MessageHeaderAccessorTests { headers.copyHeadersIfAbsent(null); assertEquals(1, headers.getMessageHeaders().size()); - assertEquals(new HashSet<>(Arrays.asList("id")), headers.getMessageHeaders().keySet()); + assertEquals(Collections.singleton("id"), headers.getMessageHeaders().keySet()); } @Test @@ -280,12 +278,7 @@ public class MessageHeaderAccessorTests { public void idGeneratorCustom() { final UUID id = new UUID(0L, 23L); MessageHeaderAccessor accessor = new MessageHeaderAccessor(); - accessor.setIdGenerator(new IdGenerator() { - @Override - public UUID generateId() { - return id; - } - }); + accessor.setIdGenerator(() -> id); assertSame(id, accessor.getMessageHeaders().getId()); } @@ -299,12 +292,7 @@ public class MessageHeaderAccessorTests { @Test public void idTimestampWithMutableHeaders() { MessageHeaderAccessor accessor = new MessageHeaderAccessor(); - accessor.setIdGenerator(new IdGenerator() { - @Override - public UUID generateId() { - return MessageHeaders.ID_VALUE_NONE; - } - }); + accessor.setIdGenerator(() -> MessageHeaders.ID_VALUE_NONE); accessor.setEnableTimestamp(false); accessor.setLeaveMutable(true); MessageHeaders headers = accessor.getMessageHeaders(); @@ -313,12 +301,7 @@ public class MessageHeaderAccessorTests { assertNull(headers.getTimestamp()); final UUID id = new UUID(0L, 23L); - accessor.setIdGenerator(new IdGenerator() { - @Override - public UUID generateId() { - return id; - } - }); + accessor.setIdGenerator(() -> id); accessor.setEnableTimestamp(true); accessor.setImmutable(); @@ -396,10 +379,25 @@ public class MessageHeaderAccessorTests { assertEquals("headers={contentType=text/plain} payload=" + sb + " > 80", actual); } + @Test + public void serializeMutableHeaders() throws Exception { + Map headers = new HashMap<>(); + headers.put("foo", "bar"); + Message message = new GenericMessage<>("test", headers); + MessageHeaderAccessor mutableAccessor = MessageHeaderAccessor.getMutableAccessor(message); + mutableAccessor.setContentType(MimeTypeUtils.TEXT_PLAIN); + + message = new GenericMessage<>(message.getPayload(), mutableAccessor.getMessageHeaders()); + Message output = (Message) SerializationTestUtils.serializeAndDeserialize(message); + assertEquals("test", output.getPayload()); + assertEquals("bar", output.getHeaders().get("foo")); + assertNotNull(output.getHeaders().get(MessageHeaders.CONTENT_TYPE)); + } + public static class TestMessageHeaderAccessor extends MessageHeaderAccessor { - private TestMessageHeaderAccessor() { + public TestMessageHeaderAccessor() { } private TestMessageHeaderAccessor(Message message) {