diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCodec.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCodec.java index bef1726ae5..a955a2dc42 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCodec.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompCodec.java @@ -53,6 +53,8 @@ public class StompCodec implements Codec, Message message = DECODER.decode(buffer.byteBuffer()); if (message != null) { next.accept(message); + } else { + break; } } return null; diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java index 7fd310ca1f..4265ade4d0 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java @@ -52,7 +52,9 @@ public class StompDecoder { public Message decode(ByteBuffer buffer) { skipLeadingEol(buffer); - Message decodedMessage; + Message decodedMessage = null; + + buffer.mark(); String command = readCommand(buffer); @@ -60,18 +62,25 @@ public class StompDecoder { MultiValueMap headers = readHeaders(buffer); byte[] payload = readPayload(buffer, headers); - StompCommand stompCommand = StompCommand.valueOf(command); - if ((payload.length > 0) && (!stompCommand.isBodyAllowed())) { - throw new StompConversionException(stompCommand + - " isn't allowed to have a body but has payload length=" + payload.length + - ", headers=" + headers); - } + if (payload != null) { + StompCommand stompCommand = StompCommand.valueOf(command); + if ((payload.length > 0) && (!stompCommand.isBodyAllowed())) { + throw new StompConversionException(stompCommand + + " isn't allowed to have a body but has payload length=" + payload.length + + ", headers=" + headers); + } - decodedMessage = MessageBuilder.withPayload(payload) - .setHeaders(StompHeaderAccessor.create(stompCommand, headers)).build(); + decodedMessage = MessageBuilder.withPayload(payload) + .setHeaders(StompHeaderAccessor.create(stompCommand, headers)).build(); - if (logger.isDebugEnabled()) { - logger.debug("Decoded " + decodedMessage); + if (logger.isDebugEnabled()) { + logger.debug("Decoded " + decodedMessage); + } + } else { + if (logger.isDebugEnabled()) { + logger.debug("Received incomplete frame. Resetting buffer"); + } + buffer.reset(); } } else { @@ -105,8 +114,10 @@ public class StompDecoder { String header = new String(headerStream.toByteArray(), UTF8_CHARSET); int colonIndex = header.indexOf(':'); if ((colonIndex <= 0) || (colonIndex == header.length() - 1)) { - throw new StompConversionException( - "Illegal header: '" + header + "'. A header must be of the form : 0) { + throw new StompConversionException( + "Illegal header: '" + header + "'. A header must be of the form :"); + } } else { String headerName = unescape(header.substring(0, colonIndex)); @@ -133,10 +144,15 @@ public class StompDecoder { if (contentLengthString != null) { int contentLength = Integer.valueOf(contentLengthString); byte[] payload = new byte[contentLength]; - buffer.get(payload); - if (buffer.remaining() < 1 || buffer.get() != 0) { - throw new StompConversionException("Frame must be terminated with a null octect"); + if (buffer.remaining() > contentLength) { + buffer.get(payload); + if (buffer.get() != 0) { + throw new StompConversionException("Frame must be terminated with a null octet"); + } + } else { + return null; } + return payload; } else { @@ -151,7 +167,7 @@ public class StompDecoder { } } } - throw new StompConversionException("Frame must be terminated with a null octect"); + return null; } private void skipLeadingEol(ByteBuffer buffer) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCodecTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCodecTests.java index 5cd5e995c1..901e4e2b14 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCodecTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompCodecTests.java @@ -158,6 +158,30 @@ public class StompCodecTests { assertEquals(StompCommand.DISCONNECT, StompHeaderAccessor.wrap(messages.get(1)).getCommand()); } + @Test + public void decodeFrameWithIncompleteHeader() { + assertIncompleteDecode("SEND\ndestination"); + assertIncompleteDecode("SEND\ndestination:"); + assertIncompleteDecode("SEND\ndestination:test"); + } + + @Test + public void decodeFrameWithoutNullOctetTerminator() { + assertIncompleteDecode("SEND\ndestination:test\n"); + assertIncompleteDecode("SEND\ndestination:test\n\n"); + assertIncompleteDecode("SEND\ndestination:test\n\nThe body"); + } + + @Test + public void decodeFrameWithInsufficientContent() { + assertIncompleteDecode("SEND\ncontent-length:23\n\nThe body of the mess"); + } + + @Test(expected=StompConversionException.class) + public void decodeFrameWithIncorrectTerminator() { + decode("SEND\ncontent-length:23\n\nThe body of the message*"); + } + @Test public void decodeHeartbeat() { String frame = "\n"; @@ -219,11 +243,28 @@ public class StompCodecTests { assertEquals("SEND\na:alpha\ncontent-length:12\n\nMessage body\0", new StompCodec().encoder().apply(frame).asString()); } + private void assertIncompleteDecode(String partialFrame) { + Buffer buffer = Buffer.wrap(partialFrame); + assertNull(decode(buffer)); + assertEquals(0, buffer.position()); + } + private Message decode(String stompFrame) { - this.decoder.apply(Buffer.wrap(stompFrame)); - return consumer.arguments.get(0); + Buffer buffer = Buffer.wrap(stompFrame); + return decode(buffer); } + private Message decode(Buffer buffer) { + this.decoder.apply(buffer); + if (consumer.arguments.isEmpty()) { + return null; + } else { + return consumer.arguments.get(0); + } + } + + + private static final class ArgumentCapturingConsumer implements Consumer { private final List arguments = new ArrayList();