@ -18,6 +18,10 @@ package org.springframework.web.socket.messaging;
@@ -18,6 +18,10 @@ package org.springframework.web.socket.messaging;
import java.lang.annotation.Retention ;
import java.lang.annotation.RetentionPolicy ;
import java.nio.ByteBuffer ;
import java.nio.charset.StandardCharsets ;
import java.util.ArrayList ;
import java.util.Arrays ;
import java.util.List ;
import java.util.concurrent.CopyOnWriteArrayList ;
import java.util.concurrent.CountDownLatch ;
@ -28,25 +32,22 @@ import org.junit.jupiter.api.TestInfo;
@@ -28,25 +32,22 @@ import org.junit.jupiter.api.TestInfo;
import org.springframework.beans.factory.annotation.Autowired ;
import org.springframework.context.annotation.Bean ;
import org.springframework.context.annotation.ComponentScan ;
import org.springframework.context.annotation.Configuration ;
import org.springframework.context.annotation.Scope ;
import org.springframework.context.annotation.ScopedProxyMode ;
import org.springframework.core.task.TaskExecutor ;
import org.springframework.messaging.handler.annotation.MessageExceptionHandler ;
import org.springframework.messaging.handler.annotation.MessageMapping ;
import org.springframework.messaging.simp.annotation.SendToUser ;
import org.springframework.messaging.simp.annotation.SubscribeMapping ;
import org.springframework.messaging.simp.config.MessageBrokerRegistry ;
import org.springframework.messaging.simp.stomp.StompCommand ;
import org.springframework.messaging.support.AbstractSubscribableChannel ;
import org.springframework.messaging.support.ExecutorSubscribableChannel ;
import org.springframework.messaging.simp.stomp.StompDecoder ;
import org.springframework.stereotype.Controller ;
import org.springframework.web.socket.AbstractWebSocketIntegrationTests ;
import org.springframework.web.socket.TextMessage ;
import org.springframework.web.socket.WebSocketSession ;
import org.springframework.web.socket.WebSocketTestServer ;
import org.springframework.web.socket.client.WebSocketClient ;
import org.springframework.web.socket.config.annotation.DelegatingWebSocketMessageBrokerConfiguration ;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker ;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry ;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer ;
import org.springframework.web.socket.handler.TextWebSocketHandler ;
@ -69,7 +70,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@@ -69,7 +70,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@Override
protected Class < ? > [ ] getAnnotatedConfigClasses ( ) {
return new Class < ? > [ ] { TestMessageBrokerConfiguration . class , TestMessageBrokerConfigur er . class } ;
return new Class < ? > [ ] { TestMessageBrokerConfigurer . class } ;
}
@ -100,7 +101,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@@ -100,7 +101,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
TextMessage m2 = create ( StompCommand . SEND )
. headers ( "destination:/app/increment" ) . body ( "5" ) . build ( ) ;
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler ( 2 , m0 , m1 , m2 ) ;
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler ( 1 , m0 , m1 , m2 ) ;
try ( WebSocketSession session = execute ( clientHandler , "/ws" ) . get ( ) ) {
assertThat ( session ) . isNotNull ( ) ;
@ -121,17 +122,49 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@@ -121,17 +122,49 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
TextMessage m1 = create ( StompCommand . SUBSCRIBE ) . headers ( "id:subs1" , destination , selector ) . build ( ) ;
TextMessage m2 = create ( StompCommand . SEND ) . headers ( destination , "foo:bar" ) . body ( "5" ) . build ( ) ;
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler ( 2 , m0 , m1 , m2 ) ;
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler ( 1 , m0 , m1 , m2 ) ;
try ( WebSocketSession session = execute ( clientHandler , "/ws" ) . get ( ) ) {
assertThat ( session ) . isNotNull ( ) ;
assertThat ( clientHandler . latch . await ( TIMEOUT , TimeUnit . SECONDS ) ) . isTrue ( ) ;
String payload = clientHandler . actual . get ( 1 ) . getPayload ( ) ;
String payload = clientHandler . actual . get ( 0 ) . getPayload ( ) ;
assertThat ( payload ) . as ( "Expected STOMP MESSAGE, got " + payload ) . startsWith ( "MESSAGE\n" ) ;
}
}
@ParameterizedWebSocketTest // gh-21798
void sendMessageToBrokerAndReceiveInOrder (
WebSocketTestServer server , WebSocketClient webSocketClient , TestInfo testInfo ) throws Exception {
super . setup ( server , webSocketClient , testInfo ) ;
String destination = "destination:/topic/foo" ;
List < TextMessage > messages = new ArrayList < > ( ) ;
messages . add ( create ( StompCommand . CONNECT ) . headers ( "accept-version:1.1" ) . build ( ) ) ;
messages . add ( create ( StompCommand . SUBSCRIBE ) . headers ( "id:subs1" , destination ) . build ( ) ) ;
int count = 1000 ;
for ( int i = 0 ; i < count ; i + + ) {
messages . add ( create ( StompCommand . SEND ) . headers ( destination ) . body ( String . valueOf ( i ) ) . build ( ) ) ;
}
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler ( count , messages ) ;
try ( WebSocketSession session = execute ( clientHandler , "/ws" ) . get ( ) ) {
assertThat ( session ) . isNotNull ( ) ;
assertThat ( clientHandler . latch . await ( TIMEOUT , TimeUnit . SECONDS ) ) . isTrue ( ) ;
for ( int i = 0 ; i < count ; i + + ) {
TextMessage message = clientHandler . actual . get ( i ) ;
ByteBuffer buffer = ByteBuffer . wrap ( message . asBytes ( ) ) ;
byte [ ] bytes = new StompDecoder ( ) . decode ( buffer ) . get ( 0 ) . getPayload ( ) ;
assertThat ( new String ( bytes , StandardCharsets . UTF_8 ) ) . isEqualTo ( String . valueOf ( i ) ) ;
}
}
}
@ParameterizedWebSocketTest // SPR-11648
void sendSubscribeToControllerAndReceiveReply (
WebSocketTestServer server , WebSocketClient webSocketClient , TestInfo testInfo ) throws Exception {
@ -142,12 +175,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@@ -142,12 +175,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
String destHeader = "destination:/app/number" ;
TextMessage m1 = create ( StompCommand . SUBSCRIBE ) . headers ( "id:subs1" , destHeader ) . build ( ) ;
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler ( 2 , m0 , m1 ) ;
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler ( 1 , m0 , m1 ) ;
try ( WebSocketSession session = execute ( clientHandler , "/ws" ) . get ( ) ) {
assertThat ( session ) . isNotNull ( ) ;
assertThat ( clientHandler . latch . await ( TIMEOUT , TimeUnit . SECONDS ) ) . isTrue ( ) ;
String payload = clientHandler . actual . get ( 1 ) . getPayload ( ) ;
String payload = clientHandler . actual . get ( 0 ) . getPayload ( ) ;
assertThat ( payload ) . as ( "Expected STOMP destination=/app/number, got " + payload ) . contains ( destHeader ) ;
assertThat ( payload ) . as ( "Expected STOMP Payload=42, got " + payload ) . contains ( "42" ) ;
}
@ -164,12 +197,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@@ -164,12 +197,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
TextMessage m1 = create ( StompCommand . SUBSCRIBE ) . headers ( "id:subs1" , destHeader ) . build ( ) ;
TextMessage m2 = create ( StompCommand . SEND ) . headers ( "destination:/app/exception" ) . build ( ) ;
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler ( 2 , m0 , m1 , m2 ) ;
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler ( 1 , m0 , m1 , m2 ) ;
try ( WebSocketSession session = execute ( clientHandler , "/ws" ) . get ( ) ) {
assertThat ( session ) . isNotNull ( ) ;
assertThat ( clientHandler . latch . await ( TIMEOUT , TimeUnit . SECONDS ) ) . isTrue ( ) ;
String payload = clientHandler . actual . get ( 1 ) . getPayload ( ) ;
String payload = clientHandler . actual . get ( 0 ) . getPayload ( ) ;
assertThat ( payload ) . startsWith ( "MESSAGE\n" ) ;
assertThat ( payload ) . contains ( "destination:/user/queue/error\n" ) ;
assertThat ( payload ) . endsWith ( "Got error: Bad input\0" ) ;
@ -188,12 +221,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@@ -188,12 +221,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
TextMessage m2 = create ( StompCommand . SEND )
. headers ( "destination:/app/scopedBeanValue" ) . build ( ) ;
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler ( 2 , m0 , m1 , m2 ) ;
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler ( 1 , m0 , m1 , m2 ) ;
try ( WebSocketSession session = execute ( clientHandler , "/ws" ) . get ( ) ) {
assertThat ( session ) . isNotNull ( ) ;
assertThat ( clientHandler . latch . await ( TIMEOUT , TimeUnit . SECONDS ) ) . isTrue ( ) ;
String payload = clientHandler . actual . get ( 1 ) . getPayload ( ) ;
String payload = clientHandler . actual . get ( 0 ) . getPayload ( ) ;
assertThat ( payload ) . startsWith ( "MESSAGE\n" ) ;
assertThat ( payload ) . contains ( "destination:/topic/scopedBeanValue\n" ) ;
assertThat ( payload ) . endsWith ( "55\0" ) ;
@ -285,18 +318,19 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@@ -285,18 +318,19 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
private static class TestClientWebSocketHandler extends TextWebSocketHandler {
private final TextMessage [ ] messagesToSend ;
private final int expected ;
private final List < TextMessage > messagesToSend ;
private final List < TextMessage > actual = new CopyOnWriteArrayList < > ( ) ;
private final CountDownLatch latch ;
public TestClientWebSocketHandler ( int expectedNumberOfMessages , TextMessage . . . messagesToSend ) {
TestClientWebSocketHandler ( int expectedNumberOfMessages , TextMessage . . . messagesToSend ) {
this ( expectedNumberOfMessages , Arrays . asList ( messagesToSend ) ) ;
}
TestClientWebSocketHandler ( int expectedNumberOfMessages , List < TextMessage > messagesToSend ) {
this . messagesToSend = messagesToSend ;
this . expected = expectedNumberOfMessages ;
this . latch = new CountDownLatch ( this . expected ) ;
this . latch = new CountDownLatch ( expectedNumberOfMessages ) ;
}
@Override
@ -307,18 +341,20 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@@ -307,18 +341,20 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
}
@Override
protected void handleTextMessage ( WebSocketSession session , TextMessage message ) throws Exception {
this . actual . add ( message ) ;
this . latch . countDown ( ) ;
protected void handleTextMessage ( WebSocketSession session , TextMessage message ) {
if ( ! message . getPayload ( ) . startsWith ( "CONNECTED" ) ) {
this . actual . add ( message ) ;
this . latch . countDown ( ) ;
}
}
}
@Configuration
@ComponentScan (
basePackageClasses = StompWebSocketIntegrationTests . class ,
useDefaultFilters = false ,
includeFilters = @ComponentScan.Filter ( IntegrationTestController . class ) )
@EnableWebSocketMessageBroker
static class TestMessageBrokerConfigurer implements WebSocketMessageBrokerConfigurer {
@Autowired
@ -326,12 +362,14 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@@ -326,12 +362,14 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@Override
public void registerStompEndpoints ( StompEndpointRegistry registry ) {
registry . setPreserveReceiveOrder ( true ) ;
registry . addEndpoint ( "/ws" ) . setHandshakeHandler ( this . handshakeHandler ) ;
}
@Override
public void configureMessageBroker ( MessageBrokerRegistry configurer ) {
configurer . setApplicationDestinationPrefixes ( "/app" ) ;
configurer . setPreservePublishOrder ( true ) ;
configurer . enableSimpleBroker ( "/topic" , "/queue" ) . setSelectorHeaderName ( "selector" ) ;
}
@ -342,21 +380,4 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@@ -342,21 +380,4 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
}
}
@Configuration
static class TestMessageBrokerConfiguration extends DelegatingWebSocketMessageBrokerConfiguration {
@Override
@Bean
public AbstractSubscribableChannel clientInboundChannel ( TaskExecutor clientInboundChannelExecutor ) {
return new ExecutorSubscribableChannel ( ) ; // synchronous
}
@Override
@Bean
public AbstractSubscribableChannel clientOutboundChannel ( TaskExecutor clientOutboundChannelExecutor ) {
return new ExecutorSubscribableChannel ( ) ; // synchronous
}
}
}