Browse Source

Add StompCodec

Previously, the broker relay's TCP client used Reactor's built in
delimited codec as part of its parsing of STOMP frames. \0 was used as
the delimiter. This worked for most STOMP frames but, crucially,
not for frames with a body that contained \0: when such a frame was
received it would be truncated.

This commit adds a custom codec that parses STOMP frames more
intelligently. It honours the content-length header allowing it to
correctly parse frames with a body that contains \0. The codec largely
delegates to two new classes: StompEncoder and StompDecoder. For
consistency, code that previously used StompMessageConverter has been
reworked to use these new encoder and decoder classes.

Issue: SPR-10818
pull/364/merge
Andy Wilkinson 11 years ago committed by Rossen Stoyanchev
parent
commit
a489c2cf38
  1. 65
      spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java
  2. 68
      spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCodec.java
  3. 157
      spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java
  4. 115
      spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java
  5. 231
      spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java
  6. 23
      spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java
  7. 3
      spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java
  8. 212
      spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCodecTests.java
  9. 153
      spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java
  10. 3
      spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java

65
spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandler.java

@ -17,7 +17,6 @@ @@ -17,7 +17,6 @@
package org.springframework.messaging.simp.stomp;
import java.net.InetSocketAddress;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
@ -34,7 +33,6 @@ import org.springframework.messaging.simp.SimpMessageType; @@ -34,7 +33,6 @@ import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.handler.AbstractBrokerMessageHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import reactor.core.Environment;
import reactor.core.composable.Composable;
@ -45,8 +43,6 @@ import reactor.function.Consumer; @@ -45,8 +43,6 @@ import reactor.function.Consumer;
import reactor.tcp.Reconnect;
import reactor.tcp.TcpClient;
import reactor.tcp.TcpConnection;
import reactor.tcp.encoding.DelimitedCodec;
import reactor.tcp.encoding.StandardCodecs;
import reactor.tcp.netty.NettyTcpClient;
import reactor.tcp.spec.TcpClientSpec;
import reactor.tuple.Tuple;
@ -74,11 +70,9 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @@ -74,11 +70,9 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
private String systemPasscode = "guest";
private final StompMessageConverter stompMessageConverter = new StompMessageConverter();
private Environment environment;
private TcpClient<String, String> tcpClient;
private TcpClient<Message<byte[]>, Message<byte[]>> tcpClient;
private final Map<String, RelaySession> relaySessions = new ConcurrentHashMap<String, RelaySession>();
@ -159,9 +153,9 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @@ -159,9 +153,9 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
@Override
protected void startInternal() {
this.environment = new Environment();
this.tcpClient = new TcpClientSpec<String, String>(NettyTcpClient.class)
this.tcpClient = new TcpClientSpec<Message<byte[]>, Message<byte[]>>(NettyTcpClient.class)
.env(this.environment)
.codec(new DelimitedCodec<String, String>((byte) 0, true, StandardCodecs.STRING_CODEC))
.codec(new StompCodec())
.connect(this.relayHost, this.relayPort)
.get();
@ -275,14 +269,14 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @@ -275,14 +269,14 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
public void connect(final Message<?> connectMessage) {
Assert.notNull(connectMessage, "connectMessage is required");
Composable<TcpConnection<String, String>> connectionComposable = openTcpConnection();
connectionComposable.consume(new Consumer<TcpConnection<String, String>>() {
Composable<TcpConnection<Message<byte[]>, Message<byte[]>>> promise = openTcpConnection();
promise.consume(new Consumer<TcpConnection<Message<byte[]>, Message<byte[]>>>() {
@Override
public void accept(TcpConnection<String, String> connection) {
public void accept(TcpConnection<Message<byte[]>, Message<byte[]>> connection) {
handleTcpConnection(connection, connectMessage);
}
});
connectionComposable.when(Throwable.class, new Consumer<Throwable>() {
promise.when(Throwable.class, new Consumer<Throwable>() {
@Override
public void accept(Throwable ex) {
relaySessions.remove(sessionId);
@ -291,29 +285,22 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @@ -291,29 +285,22 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
});
}
protected Composable<TcpConnection<String, String>> openTcpConnection() {
protected Composable<TcpConnection<Message<byte[]>, Message<byte[]>>> openTcpConnection() {
return tcpClient.open();
}
protected void handleTcpConnection(TcpConnection<String, String> tcpConn, final Message<?> connectMessage) {
protected void handleTcpConnection(TcpConnection<Message<byte[]>, Message<byte[]>> tcpConn, final Message<?> connectMessage) {
this.stompConnection.setTcpConnection(tcpConn);
tcpConn.in().consume(new Consumer<String>() {
tcpConn.in().consume(new Consumer<Message<byte[]>>() {
@Override
public void accept(String message) {
public void accept(Message<byte[]> message) {
readStompFrame(message);
}
});
forwardInternal(tcpConn, connectMessage);
}
private void readStompFrame(String stompFrame) {
// heartbeat
if (StringUtils.isEmpty(stompFrame)) {
return;
}
Message<?> message = stompMessageConverter.toMessage(stompFrame);
private void readStompFrame(Message<byte[]> message) {
if (logger.isTraceEnabled()) {
logger.trace("Reading message " + message);
}
@ -378,24 +365,24 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @@ -378,24 +365,24 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
}
private boolean forwardInternal(final Message<?> message) {
TcpConnection<String, String> tcpConnection = this.stompConnection.getReadyConnection();
TcpConnection<Message<byte[]>, Message<byte[]>> tcpConnection = this.stompConnection.getReadyConnection();
if (tcpConnection == null) {
return false;
}
return forwardInternal(tcpConnection, message);
}
private boolean forwardInternal(TcpConnection<String, String> tcpConnection, final Message<?> message) {
@SuppressWarnings("unchecked")
private boolean forwardInternal(TcpConnection<Message<byte[]>, Message<byte[]>> tcpConnection, final Message<?> message) {
Assert.isInstanceOf(byte[].class, message.getPayload(), "Message's payload must be a byte[]");
if (logger.isTraceEnabled()) {
logger.trace("Forwarding to STOMP broker, message: " + message);
}
byte[] bytes = stompMessageConverter.fromMessage(message);
String payload = new String(bytes, Charset.forName("UTF-8"));
final Deferred<Boolean, Promise<Boolean>> deferred = new DeferredPromiseSpec<Boolean>().get();
tcpConnection.send(payload, new Consumer<Boolean>() {
tcpConnection.send((Message<byte[]>)message, new Consumer<Boolean>() {
@Override
public void accept(Boolean success) {
deferred.accept(success);
@ -434,18 +421,22 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @@ -434,18 +421,22 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
private static class StompConnection {
private volatile TcpConnection<String, String> connection;
private volatile TcpConnection<Message<byte[]>, Message<byte[]>> connection;
private AtomicReference<TcpConnection<String, String>> readyConnection =
new AtomicReference<TcpConnection<String, String>>();
private AtomicReference<TcpConnection<Message<byte[]>, Message<byte[]>>> readyConnection =
new AtomicReference<TcpConnection<Message<byte[]>, Message<byte[]>>>();
public void setTcpConnection(TcpConnection<String, String> connection) {
public void setTcpConnection(TcpConnection<Message<byte[]>, Message<byte[]>> connection) {
Assert.notNull(connection, "connection must not be null");
this.connection = connection;
}
public TcpConnection<String, String> getReadyConnection() {
/**
* Return the underlying {@link TcpConnection} but only after the CONNECTED STOMP
* frame is received.
*/
public TcpConnection<Message<byte[]>, Message<byte[]>> getReadyConnection() {
return this.readyConnection.get();
}
@ -488,7 +479,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler @@ -488,7 +479,7 @@ public class StompBrokerRelayMessageHandler extends AbstractBrokerMessageHandler
}
@Override
protected Composable<TcpConnection<String, String>> openTcpConnection() {
protected Composable<TcpConnection<Message<byte[]>, Message<byte[]>>> openTcpConnection() {
return tcpClient.open(new Reconnect() {
@Override
public Tuple2<InetSocketAddress, Long> reconnect(InetSocketAddress address, int attempt) {

68
spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCodec.java

@ -0,0 +1,68 @@ @@ -0,0 +1,68 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import org.springframework.messaging.Message;
import reactor.function.Consumer;
import reactor.function.Function;
import reactor.io.Buffer;
import reactor.tcp.encoding.Codec;
/**
* A Reactor TCP {@link Codec} for sending and receiving STOMP messages
*
* @author Andy Wilkinson
* @since 4.0
*/
public class StompCodec implements Codec<Buffer, Message<byte[]>, Message<byte[]>> {
private static final StompDecoder DECODER = new StompDecoder();
private static final Function<Message<byte[]>, Buffer> ENCODER_FUNCTION = new Function<Message<byte[]>, Buffer>() {
private final StompEncoder encoder = new StompEncoder();
@Override
public Buffer apply(Message<byte[]> message) {
return Buffer.wrap(this.encoder.encode(message));
}
};
@Override
public Function<Buffer, Message<byte[]>> decoder(final Consumer<Message<byte[]>> next) {
return new Function<Buffer, Message<byte[]>>() {
@Override
public Message<byte[]> apply(Buffer buffer) {
while (buffer.remaining() > 0) {
Message<byte[]> message = DECODER.decode(buffer.byteBuffer());
if (message != null) {
next.accept(message);
}
}
return null;
}
};
}
@Override
public Function<Message<byte[]>, Buffer> encoder() {
return ENCODER_FUNCTION;
}
}

157
spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java

@ -0,0 +1,157 @@ @@ -0,0 +1,157 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* A decoder for STOMP frames
*
* @author awilkinson
* @since 4.0
*/
public class StompDecoder {
private static final Charset UTF8_CHARSET = Charset.forName("UTF-8");
/**
* Decodes a STOMP frame in the given {@code buffer} into a {@link Message}.
*
* @param buffer The buffer to decode the frame from
* @return The decoded message
*/
public Message<byte[]> decode(ByteBuffer buffer) {
skipLeadingEol(buffer);
String command = readCommand(buffer);
if (command.length() > 0) {
MultiValueMap<String, String> headers = readHeaders(buffer);
byte[] payload = readPayload(buffer, headers);
return MessageBuilder.withPayloadAndHeaders(payload,
StompHeaderAccessor.create(StompCommand.valueOf(command), headers)).build();
}
else {
// Heartbeat
return null;
}
}
private String readCommand(ByteBuffer buffer) {
ByteArrayOutputStream command = new ByteArrayOutputStream();
while (buffer.remaining() > 0 && !isEol(buffer)) {
command.write(buffer.get());
}
return new String(command.toByteArray(), UTF8_CHARSET);
}
private MultiValueMap<String, String> readHeaders(ByteBuffer buffer) {
MultiValueMap<String, String> headers = new LinkedMultiValueMap<String, String>();
while (true) {
ByteArrayOutputStream headerStream = new ByteArrayOutputStream();
while (buffer.remaining() > 0 && !isEol(buffer)) {
headerStream.write(buffer.get());
}
if (headerStream.size() > 0) {
String header = new String(headerStream.toByteArray(), UTF8_CHARSET);
int colonIndex = header.indexOf(':');
if (colonIndex <= 0 || colonIndex == header.length() - 1) {
throw new StompConversionException(
"Illegal header: '" + header + "'. A header must be of the form <name>:<value");
}
else {
String headerName = unescape(header.substring(0, colonIndex));
String headerValue = unescape(header.substring(colonIndex + 1));
headers.add(headerName, headerValue);
}
}
else {
break;
}
}
return headers;
}
private String unescape(String input) {
return input.replaceAll("\\\\n", "\n")
.replaceAll("\\\\r", "\r")
.replaceAll("\\\\c", ":")
.replaceAll("\\\\\\\\", "\\\\");
}
private byte[] readPayload(ByteBuffer buffer, MultiValueMap<String, String> headers) {
String contentLengthString = headers.getFirst("content-length");
if (contentLengthString != null) {
int contentLength = Integer.valueOf(contentLengthString);
byte[] payload = new byte[contentLength];
buffer.get(payload);
if (buffer.remaining() < 1 || buffer.get() != 0) {
throw new StompConversionException("Frame must be terminated with a null octect");
}
return payload;
}
else {
ByteArrayOutputStream payload = new ByteArrayOutputStream();
while (buffer.remaining() > 0) {
byte b = buffer.get();
if (b == 0) {
return payload.toByteArray();
}
else {
payload.write(b);
}
}
}
throw new StompConversionException("Frame must be terminated with a null octect");
}
private void skipLeadingEol(ByteBuffer buffer) {
while (true) {
if (!isEol(buffer)) {
break;
}
}
}
private boolean isEol(ByteBuffer buffer) {
if (buffer.remaining() > 0) {
byte b = buffer.get();
if (b == '\n') {
return true;
}
else if (b == '\r') {
if (buffer.remaining() > 0 && buffer.get() == '\n') {
return true;
}
else {
throw new StompConversionException("'\\r' must be followed by '\\n'");
}
}
buffer.position(buffer.position() - 1);
}
return false;
}
}

115
spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java

@ -0,0 +1,115 @@ @@ -0,0 +1,115 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map.Entry;
import org.springframework.messaging.Message;
/**
* An encoder for STOMP frames
*
* @author Andy Wilkinson
* @since 4.0
*/
public final class StompEncoder {
private static final byte LF = '\n';
private static final byte COLON = ':';
private static final Charset UTF8_CHARSET = Charset.forName("UTF-8");
/**
* Encodes the given STOMP {@code message} into a {@code byte[]}
*
* @param message The message to encode
*
* @return The encoded message
*/
public byte[] encode(Message<byte[]> message) {
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream output = new DataOutputStream(baos);
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
writeCommand(headers, output);
writeHeaders(headers, message, output);
output.write(LF);
writeBody(message, output);
output.write((byte)0);
return baos.toByteArray();
}
catch (IOException e) {
throw new StompConversionException("Failed to encode STOMP frame", e);
}
}
private void writeCommand(StompHeaderAccessor headers, DataOutputStream output) throws IOException {
output.write(headers.getCommand().toString().getBytes(UTF8_CHARSET));
output.write(LF);
}
private void writeHeaders(StompHeaderAccessor headers, Message<byte[]> message, DataOutputStream output)
throws IOException {
for (Entry<String, List<String>> entry : headers.toStompHeaderMap().entrySet()) {
byte[] key = getUtf8BytesEscapingIfNecessary(entry.getKey(), headers);
for (String value : entry.getValue()) {
output.write(key);
output.write(COLON);
output.write(getUtf8BytesEscapingIfNecessary(value, headers));
output.write(LF);
}
}
if (headers.getCommand() == StompCommand.SEND ||
headers.getCommand() == StompCommand.MESSAGE ||
headers.getCommand() == StompCommand.ERROR) {
output.write("content-length:".getBytes(UTF8_CHARSET));
output.write(Integer.toString(message.getPayload().length).getBytes(UTF8_CHARSET));
output.write(LF);
}
}
private void writeBody(Message<byte[]> message, DataOutputStream output) throws IOException {
output.write(message.getPayload());
}
private byte[] getUtf8BytesEscapingIfNecessary(String input, StompHeaderAccessor headers) {
if (headers.getCommand() != StompCommand.CONNECT && headers.getCommand() != StompCommand.CONNECTED) {
return escape(input).getBytes(UTF8_CHARSET);
}
else {
return input.getBytes(UTF8_CHARSET);
}
}
private String escape(String input) {
return input.replaceAll("\\\\", "\\\\\\\\")
.replaceAll(":", "\\\\c")
.replaceAll("\n", "\\\\n")
.replaceAll("\r", "\\\\r");
}
}

231
spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompMessageConverter.java

@ -1,231 +0,0 @@ @@ -1,231 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map.Entry;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* @author Gary Russell
* @author Rossen Stoyanchev
* @since 4.0
*/
public class StompMessageConverter {
private static final Charset STOMP_CHARSET = Charset.forName("UTF-8");
public static final byte LF = 0x0a;
public static final byte CR = 0x0d;
private static final byte COLON = ':';
/**
* @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String.
*/
public Message<?> toMessage(Object stompContent) {
byte[] byteContent = null;
if (stompContent instanceof String) {
byteContent = ((String) stompContent).getBytes(STOMP_CHARSET);
}
else if (stompContent instanceof byte[]){
byteContent = (byte[]) stompContent;
}
else {
throw new IllegalArgumentException(
"stompContent is neither String nor byte[]: " + stompContent.getClass());
}
int totalLength = byteContent.length;
if (byteContent[totalLength-1] == 0) {
totalLength--;
}
int payloadIndex = findIndexOfPayload(byteContent);
if (payloadIndex == 0) {
throw new StompConversionException("No command found");
}
String headerContent = new String(byteContent, 0, payloadIndex, STOMP_CHARSET);
Parser parser = new Parser(headerContent);
StompCommand command = StompCommand.valueOf(parser.nextToken(LF).trim());
Assert.notNull(command, "No command found");
MultiValueMap<String, String> headers = new LinkedMultiValueMap<String, String>();
while (parser.hasNext()) {
String header = parser.nextToken(COLON);
if (header != null) {
if (parser.hasNext()) {
String value = parser.nextToken(LF);
headers.add(header, value);
}
else {
throw new StompConversionException("Parse exception for " + headerContent);
}
}
}
byte[] payload = new byte[totalLength - payloadIndex];
System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers);
return MessageBuilder.withPayloadAndHeaders(payload, stompHeaders).build();
}
private int findIndexOfPayload(byte[] bytes) {
int i;
// ignore any leading EOL from the previous message
for (i = 0; i < bytes.length; i++) {
if (bytes[i] != '\n' && bytes[i] != '\r') {
break;
}
bytes[i] = ' ';
}
int index = 0;
for (; i < bytes.length - 1; i++) {
if (bytes[i] == LF && bytes[i+1] == LF) {
index = i + 2;
break;
}
if ((i < (bytes.length - 3)) &&
(bytes[i] == CR && bytes[i+1] == LF && bytes[i+2] == CR && bytes[i+3] == LF)) {
index = i + 4;
break;
}
}
if (i >= bytes.length) {
throw new StompConversionException("No end of headers found");
}
return index;
}
public byte[] fromMessage(Message<?> message) {
byte[] payload;
if (message.getPayload() instanceof byte[]) {
payload = (byte[]) message.getPayload();
}
else {
throw new IllegalArgumentException(
"stompContent is not byte[]: " + message.getPayload().getClass());
}
ByteArrayOutputStream out = new ByteArrayOutputStream();
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
try {
out.write(stompHeaders.getCommand().toString().getBytes("UTF-8"));
out.write(LF);
for (Entry<String, List<String>> entry : stompHeaders.toStompHeaderMap().entrySet()) {
String key = entry.getKey();
key = replaceAllOutbound(key);
for (String value : entry.getValue()) {
out.write(key.getBytes("UTF-8"));
out.write(COLON);
value = replaceAllOutbound(value);
out.write(value.getBytes("UTF-8"));
out.write(LF);
}
}
out.write(LF);
out.write(payload);
out.write(0);
return out.toByteArray();
}
catch (IOException e) {
throw new StompConversionException("Failed to serialize " + message, e);
}
}
private String replaceAllOutbound(String key) {
return key.replaceAll("\\\\", "\\\\")
.replaceAll(":", "\\\\c")
.replaceAll("\n", "\\\\n")
.replaceAll("\r", "\\\\r");
}
private class Parser {
private final String content;
private int offset;
public Parser(String content) {
this.content = content;
}
public boolean hasNext() {
return this.offset < this.content.length();
}
public String nextToken(byte delimiter) {
if (this.offset >= this.content.length()) {
return null;
}
int delimAt = this.content.indexOf(delimiter, this.offset);
if (delimAt == -1) {
if (this.offset == this.content.length() - 1 && delimiter == COLON &&
this.content.charAt(this.offset) == LF) {
this.offset++;
return null;
}
else if (this.offset == this.content.length() - 2 && delimiter == COLON &&
this.content.charAt(this.offset) == CR &&
this.content.charAt(this.offset + 1) == LF) {
this.offset += 2;
return null;
}
else {
throw new StompConversionException("No delimiter found at offset " + offset + " in " + this.content);
}
}
int escapeAt = this.content.indexOf('\\', this.offset);
String token = this.content.substring(this.offset, delimAt + 1);
this.offset += token.length();
if (escapeAt >= 0 && escapeAt < delimAt) {
char escaped = this.content.charAt(escapeAt + 1);
if (escaped == 'n' || escaped == 'c' || escaped == '\\') {
token = token.replaceAll("\\\\n", "\n")
.replaceAll("\\\\r", "\r")
.replaceAll("\\\\c", ":")
.replaceAll("\\\\\\\\", "\\\\");
}
else {
throw new StompConversionException("Invalid escape sequence \\" + escaped);
}
}
int length = token.length();
if (delimiter == LF && length > 1 && token.charAt(length - 2) == CR) {
return token.substring(0, length - 2);
}
else {
return token.substring(0, length - 1);
}
}
}
}

23
spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompProtocolHandler.java

@ -17,6 +17,7 @@ @@ -17,6 +17,7 @@
package org.springframework.messaging.simp.stomp;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.security.Principal;
import java.util.Arrays;
@ -63,7 +64,9 @@ public class StompProtocolHandler implements SubProtocolHandler { @@ -63,7 +64,9 @@ public class StompProtocolHandler implements SubProtocolHandler {
private final Log logger = LogFactory.getLog(StompProtocolHandler.class);
private final StompMessageConverter stompMessageConverter = new StompMessageConverter();
private final StompDecoder stompDecoder = new StompDecoder();
private final StompEncoder stompEncoder = new StompEncoder();
private MutableUserQueueSuffixResolver queueSuffixResolver = new SimpleUserQueueSuffixResolver();
@ -98,7 +101,8 @@ public class StompProtocolHandler implements SubProtocolHandler { @@ -98,7 +101,8 @@ public class StompProtocolHandler implements SubProtocolHandler {
try {
Assert.isInstanceOf(TextMessage.class, webSocketMessage);
String payload = ((TextMessage)webSocketMessage).getPayload();
message = this.stompMessageConverter.toMessage(payload);
ByteBuffer byteBuffer = ByteBuffer.wrap(payload.getBytes(Charset.forName("UTF-8")));
message = this.stompDecoder.decode(byteBuffer);
}
catch (Throwable error) {
logger.error("Failed to parse STOMP frame, WebSocket message payload: ", error);
@ -133,6 +137,7 @@ public class StompProtocolHandler implements SubProtocolHandler { @@ -133,6 +137,7 @@ public class StompProtocolHandler implements SubProtocolHandler {
/**
* Handle STOMP messages going back out to WebSocket clients.
*/
@SuppressWarnings("unchecked")
@Override
public void handleMessageToClient(WebSocketSession session, Message<?> message) {
@ -156,7 +161,7 @@ public class StompProtocolHandler implements SubProtocolHandler { @@ -156,7 +161,7 @@ public class StompProtocolHandler implements SubProtocolHandler {
try {
message = MessageBuilder.withPayloadAndHeaders(message.getPayload(), headers).build();
byte[] bytes = this.stompMessageConverter.fromMessage(message);
byte[] bytes = this.stompEncoder.encode((Message<byte[]>)message);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
catch (Throwable t) {
@ -204,19 +209,19 @@ public class StompProtocolHandler implements SubProtocolHandler { @@ -204,19 +209,19 @@ public class StompProtocolHandler implements SubProtocolHandler {
}
}
Message<?> connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build();
byte[] bytes = this.stompMessageConverter.fromMessage(connectedMessage);
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
Message<byte[]> connectedMessage = MessageBuilder.withPayloadAndHeaders(new byte[0], connectedHeaders).build();
String payload = new String(this.stompEncoder.encode(connectedMessage), Charset.forName("UTF-8"));
session.sendMessage(new TextMessage(payload));
}
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);
headers.setMessage(error.getMessage());
Message<?> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
byte[] bytes = this.stompMessageConverter.fromMessage(message);
Message<byte[]> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
String payload = new String(this.stompEncoder.encode(message), Charset.forName("UTF-8"));
try {
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
session.sendMessage(new TextMessage(payload));
}
catch (Throwable t) {
// ignore

3
spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompBrokerRelayMessageHandlerIntegrationTests.java

@ -253,7 +253,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { @@ -253,7 +253,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
}
public void awaitAndAssert() throws InterruptedException {
boolean result = this.latch.await(5000, TimeUnit.MILLISECONDS);
boolean result = this.latch.await(10000, TimeUnit.MILLISECONDS);
assertTrue(getAsString(), result && this.unexpected.isEmpty());
}
@ -356,6 +356,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests { @@ -356,6 +356,7 @@ public class StompBrokerRelayMessageHandlerIntegrationTests {
public static MessageExchangeBuilder connect(String sessionId) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setSessionId(sessionId);
headers.setAcceptVersion("1.1,1.2");
Message<?> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
return new MessageExchangeBuilder(message);
}

212
spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCodecTests.java

@ -0,0 +1,212 @@ @@ -0,0 +1,212 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.List;
import org.junit.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import reactor.function.Consumer;
import reactor.function.Function;
import reactor.io.Buffer;
import static org.junit.Assert.*;
/**
*
* @author awilkinson
*/
public class StompCodecTests {
private final ArgumentCapturingConsumer<Message<byte[]>> consumer = new ArgumentCapturingConsumer<Message<byte[]>>();
private final Function<Buffer, Message<byte[]>> decoder = new StompCodec().decoder(consumer);
@Test
public void decodeFrameWithCrLfEols() {
Message<byte[]> frame = decode("DISCONNECT\r\n\r\n\0");
StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
assertEquals(StompCommand.DISCONNECT, headers.getCommand());
assertEquals(0, headers.toStompHeaderMap().size());
assertEquals(0, frame.getPayload().length);
}
@Test
public void decodeFrameWithNoHeadersAndNoBody() {
Message<byte[]> frame = decode("DISCONNECT\n\n\0");
StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
assertEquals(StompCommand.DISCONNECT, headers.getCommand());
assertEquals(0, headers.toStompHeaderMap().size());
assertEquals(0, frame.getPayload().length);
}
@Test
public void decodeFrameWithNoBody() {
String accept = "accept-version:1.1\n";
String host = "host:github.org\n";
Message<byte[]> frame = decode("CONNECT\n" + accept + host + "\n\0");
StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
assertEquals(StompCommand.CONNECT, headers.getCommand());
assertEquals(2, headers.toStompHeaderMap().size());
assertEquals("1.1", headers.getFirstNativeHeader("accept-version"));
assertEquals("github.org", headers.getHost());
assertEquals(0, frame.getPayload().length);
}
@Test
public void decodeFrame() throws UnsupportedEncodingException {
Message<byte[]> frame = decode("SEND\ndestination:test\n\nThe body of the message\0");
StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
assertEquals(StompCommand.SEND, headers.getCommand());
assertEquals(1, headers.toStompHeaderMap().size());
assertEquals("test", headers.getDestination());
String bodyText = new String(frame.getPayload());
assertEquals("The body of the message", bodyText);
}
@Test
public void decodeFrameWithContentLength() {
Message<byte[]> frame = decode("SEND\ncontent-length:23\n\nThe body of the message\0");
StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
assertEquals(StompCommand.SEND, headers.getCommand());
assertEquals(1, headers.toStompHeaderMap().size());
assertEquals(Integer.valueOf(23), headers.getContentLength());
String bodyText = new String(frame.getPayload());
assertEquals("The body of the message", bodyText);
}
@Test
public void decodeFrameWithNullOctectsInTheBody() {
Message<byte[]> frame = decode("SEND\ncontent-length:23\n\nThe b\0dy \0f the message\0");
StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
assertEquals(StompCommand.SEND, headers.getCommand());
assertEquals(1, headers.toStompHeaderMap().size());
assertEquals(Integer.valueOf(23), headers.getContentLength());
String bodyText = new String(frame.getPayload());
assertEquals("The b\0dy \0f the message", bodyText);
}
@Test
public void decodeFrameWithEscapedHeaders() {
Message<byte[]> frame = decode("DISCONNECT\na\\c\\r\\n\\\\b:alpha\\cbravo\\r\\n\\\\\n\n\0");
StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame);
assertEquals(StompCommand.DISCONNECT, headers.getCommand());
assertEquals(1, headers.toStompHeaderMap().size());
assertEquals("alpha:bravo\r\n\\", headers.getFirstNativeHeader("a:\r\n\\b"));
}
@Test
public void decodeMultipleFramesFromSameBuffer() {
String frame1 = "SEND\ndestination:test\n\nThe body of the message\0";
String frame2 = "DISCONNECT\n\n\0";
Buffer buffer = Buffer.wrap(frame1 + frame2);
final List<Message<byte[]>> messages = new ArrayList<Message<byte[]>>();
new StompCodec().decoder(new Consumer<Message<byte[]>>() {
@Override
public void accept(Message<byte[]> message) {
messages.add(message);
}
}).apply(buffer);
assertEquals(2, messages.size());
assertEquals(StompCommand.SEND, StompHeaderAccessor.wrap(messages.get(0)).getCommand());
assertEquals(StompCommand.DISCONNECT, StompHeaderAccessor.wrap(messages.get(1)).getCommand());
}
@Test
public void encodeFrameWithNoHeadersAndNoBody() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
Message<byte[]> frame = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
assertEquals("DISCONNECT\n\n\0", new StompCodec().encoder().apply(frame).asString());
}
@Test
public void encodeFrameWithHeaders() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
headers.setAcceptVersion("1.2");
headers.setHost("github.org");
Message<byte[]> frame = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
String frameString = new StompCodec().encoder().apply(frame).asString();
assertTrue(frameString.equals("CONNECT\naccept-version:1.2\nhost:github.org\n\n\0") ||
frameString.equals("CONNECT\nhost:github.org\naccept-version:1.2\n\n\0"));
}
@Test
public void encodeFrameWithHeadersThatShouldBeEscaped() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT);
headers.addNativeHeader("a:\r\n\\b", "alpha:bravo\r\n\\");
Message<byte[]> frame = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
assertEquals("DISCONNECT\na\\c\\r\\n\\\\b:alpha\\cbravo\\r\\n\\\\\n\n\0", new StompCodec().encoder().apply(frame).asString());
}
@Test
public void encodeFrameWithHeadersBody() {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.addNativeHeader("a", "alpha");
Message<byte[]> frame = MessageBuilder.withPayloadAndHeaders("Message body".getBytes(), headers).build();
assertEquals("SEND\na:alpha\ncontent-length:12\n\nMessage body\0", new StompCodec().encoder().apply(frame).asString());
}
private Message<byte[]> decode(String stompFrame) {
this.decoder.apply(Buffer.wrap(stompFrame));
return consumer.arguments.get(0);
}
private static final class ArgumentCapturingConsumer<T> implements Consumer<T> {
private final List<T> arguments = new ArrayList<T>();
@Override
public void accept(T t) {
arguments.add(t);
}
}
}

153
spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompMessageConverterTests.java

@ -1,153 +0,0 @@ @@ -1,153 +0,0 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.stomp;
import java.util.Collections;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.web.socket.TextMessage;
import static org.junit.Assert.*;
/**
* @author Gary Russell
* @author Rossen Stoyanchev
*/
public class StompMessageConverterTests {
private StompMessageConverter converter;
@Before
public void setup() {
this.converter = new StompMessageConverter();
}
@Test
public void connectFrame() throws Exception {
String accept = "accept-version:1.1";
String host = "host:github.org";
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT)
.headers(accept, host).build();
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(textMessage.getPayload());
assertEquals(0, message.getPayload().length);
MessageHeaders headers = message.getHeaders();
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
Map<String, Object> map = stompHeaders.toMap();
assertEquals(5, map.size());
assertNotNull(stompHeaders.getId());
assertNotNull(stompHeaders.getTimestamp());
assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType());
assertEquals(StompCommand.CONNECT, stompHeaders.getCommand());
assertNotNull(map.get(SimpMessageHeaderAccessor.NATIVE_HEADERS));
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getHost());
assertEquals(SimpMessageType.CONNECT, stompHeaders.getMessageType());
assertEquals(StompCommand.CONNECT, stompHeaders.getCommand());
assertNotNull(headers.get(MessageHeaders.ID));
assertNotNull(headers.get(MessageHeaders.TIMESTAMP));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
assertEquals("CONNECT\n", convertedBack.substring(0,8));
assertTrue(convertedBack.contains(accept));
assertTrue(convertedBack.contains(host));
}
@Test
public void connectWithEscapes() throws Exception {
String accept = "accept-version:1.1";
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org";
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT)
.headers(accept, host).build();
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(textMessage.getPayload());
assertEquals(0, message.getPayload().length);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.toNativeHeaderMap().get("ho:\ns\rt").get(0));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
assertEquals("CONNECT\n", convertedBack.substring(0,8));
assertTrue(convertedBack.contains(accept));
assertTrue(convertedBack.contains(host));
}
@Test
public void connectCR12() throws Exception {
String accept = "accept-version:1.2\n";
String host = "host:github.org\n";
String test = "CONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(test.getBytes("UTF-8"));
assertEquals(0, message.getPayload().length);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
assertEquals(Collections.singleton("1.2"), stompHeaders.getAcceptVersion());
assertEquals("github.org", stompHeaders.getHost());
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
assertEquals("CONNECT\n", convertedBack.substring(0,8));
assertTrue(convertedBack.contains(accept));
assertTrue(convertedBack.contains(host));
}
@Test
public void connectWithEscapesAndCR12() throws Exception {
String accept = "accept-version:1.1\n";
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
String test = "\n\n\nCONNECT\r\n" + accept.replaceAll("\n", "\r\n") + host.replaceAll("\n", "\r\n") + "\r\n";
@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(test.getBytes("UTF-8"));
assertEquals(0, message.getPayload().length);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.wrap(message);
assertEquals(Collections.singleton("1.1"), stompHeaders.getAcceptVersion());
assertEquals("st\nomp.gi:thu\\b.org", stompHeaders.toNativeHeaderMap().get("ho:\ns\rt").get(0));
String convertedBack = new String(this.converter.fromMessage(message), "UTF-8");
assertEquals("CONNECT\n", convertedBack.substring(0,8));
assertTrue(convertedBack.contains(accept));
assertTrue(convertedBack.contains(host));
}
}

3
spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompProtocolHandlerTests.java

@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
package org.springframework.messaging.simp.stomp;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashSet;
@ -84,7 +85,7 @@ public class StompProtocolHandlerTests { @@ -84,7 +85,7 @@ public class StompProtocolHandlerTests {
assertEquals(1, this.session.getSentMessages().size());
textMessage = (TextMessage) this.session.getSentMessages().get(0);
Message<?> message = new StompMessageConverter().toMessage(textMessage.getPayload());
Message<?> message = new StompDecoder().decode(ByteBuffer.wrap(textMessage.getPayload().getBytes()));
StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(message);
assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand());

Loading…
Cancel
Save