Browse Source

Handle STOMP messages from client in order

See gh-21798
pull/31416/head
rstoyanchev 1 year ago
parent
commit
a205eab618
  1. 7
      spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java
  2. 16
      spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompEndpointRegistry.java
  3. 9
      spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java
  4. 5
      spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java
  5. 37
      spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
  6. 105
      spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java

7
spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java

@ -113,7 +113,12 @@ public class OrderedMessageChannelDecorator implements MessageChannel { @@ -113,7 +113,12 @@ public class OrderedMessageChannelDecorator implements MessageChannel {
}
}
/**
* Remove the message from the top of the queue, but only if it matches,
* i.e. hasn't been removed already.
*/
private boolean removeMessage(Message<?> message) {
// Remove only if not removed already
Message<?> next = this.messages.peek();
if (next == message) {
this.messages.remove();
@ -181,7 +186,7 @@ public class OrderedMessageChannelDecorator implements MessageChannel { @@ -181,7 +186,7 @@ public class OrderedMessageChannelDecorator implements MessageChannel {
@Override
public void run() {
if (this.handledCount == null || this.handledCount.addAndGet(1) == subscriberCount) {
if (OrderedMessageChannelDecorator.this.removeMessage(message)) {
if (OrderedMessageChannelDecorator.this.removeMessage(this.message)) {
sendNextMessage();
}
}

16
spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompEndpointRegistry.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -52,4 +52,18 @@ public interface StompEndpointRegistry { @@ -52,4 +52,18 @@ public interface StompEndpointRegistry {
*/
WebMvcStompEndpointRegistry setErrorHandler(StompSubProtocolErrorHandler errorHandler);
/**
* Whether to handle client messages sequentially in the order in which
* they were received.
* <p>By default messages sent to the {@code "clientInboundChannel"} may
* be handled in parallel and not in the same order as they were received
* because the channel is backed by a ThreadPoolExecutor that in turn does
* not guarantee processing in order.
* <p>When this flag is set to {@code true} messages within the same session
* will be sent to the {@code "clientInboundChannel"} one at a time in
* order to preserve the order in which they were received.
* @since 6.1
*/
WebMvcStompEndpointRegistry setPreserveReceiveOrder(boolean preserveReceiveOrder);
}

9
spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompEndpointRegistry.java

@ -142,6 +142,15 @@ public class WebMvcStompEndpointRegistry implements StompEndpointRegistry { @@ -142,6 +142,15 @@ public class WebMvcStompEndpointRegistry implements StompEndpointRegistry {
return this;
}
public WebMvcStompEndpointRegistry setPreserveReceiveOrder(boolean preserveReceiveOrder) {
this.stompHandler.setPreserveReceiveOrder(preserveReceiveOrder);
return this;
}
protected boolean isPreserveReceiveOrder() {
return this.stompHandler.isPreserveReceiveOrder();
}
protected void setApplicationContext(ApplicationContext applicationContext) {
this.stompHandler.setApplicationEventPublisher(applicationContext);
}

5
spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java

@ -28,6 +28,7 @@ import org.springframework.messaging.simp.SimpMessagingTemplate; @@ -28,6 +28,7 @@ import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.SimpSessionScope;
import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler;
import org.springframework.messaging.simp.broker.AbstractBrokerMessageHandler;
import org.springframework.messaging.simp.broker.OrderedMessageChannelDecorator;
import org.springframework.messaging.simp.config.AbstractMessageBrokerConfiguration;
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
import org.springframework.messaging.simp.user.SimpUserRegistry;
@ -80,7 +81,8 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac @@ -80,7 +81,8 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac
@Bean
public HandlerMapping stompWebSocketHandlerMapping(
WebSocketHandler subProtocolWebSocketHandler, TaskScheduler messageBrokerTaskScheduler) {
WebSocketHandler subProtocolWebSocketHandler, TaskScheduler messageBrokerTaskScheduler,
AbstractSubscribableChannel clientInboundChannel) {
WebSocketHandler handler = decorateWebSocketHandler(subProtocolWebSocketHandler);
WebMvcStompEndpointRegistry registry =
@ -90,6 +92,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac @@ -90,6 +92,7 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac
registry.setApplicationContext(applicationContext);
}
registerStompEndpoints(registry);
OrderedMessageChannelDecorator.configureInterceptor(clientInboundChannel, registry.isPreserveReceiveOrder());
return registry.getHandlerMapping();
}

37
spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java

@ -108,6 +108,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -108,6 +108,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
@Nullable
private MessageHeaderInitializer headerInitializer;
private boolean preserveReceiveOrder;
private final Map<String, MessageChannel> messageChannels = new ConcurrentHashMap<>();
private final Map<String, Principal> stompAuthentications = new ConcurrentHashMap<>();
@Nullable
@ -193,6 +197,30 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -193,6 +197,30 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
return this.headerInitializer;
}
/**
* Whether client messages must be handled in the order received.
* <p>By default messages sent to the {@code "clientInboundChannel"} may
* not be handled in the same order because the channel is backed by a
* ThreadPoolExecutor that in turn does not guarantee processing in order.
* <p>When this flag is set to {@code true} messages within the same session
* will be sent to the {@code "clientInboundChannel"} one at a time to
* preserve the order in which they were received.
* @param preserveReceiveOrder whether to publish in order
* @since 6.1
*/
public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
this.preserveReceiveOrder = preserveReceiveOrder;
}
/**
* Whether the handler is configured to handle inbound messages in the
* order in which they were received.
* @since 6.1
*/
public boolean isPreserveReceiveOrder() {
return this.preserveReceiveOrder;
}
@Override
public List<String> getSupportedProtocols() {
return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp");
@ -268,6 +296,12 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -268,6 +296,12 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
return;
}
MessageChannel channelToUse =
(this.messageChannels.computeIfAbsent(session.getId(),
id -> this.preserveReceiveOrder ?
new OrderedMessageChannelDecorator(outputChannel, logger) :
outputChannel));
for (Message<byte[]> message : messages) {
StompHeaderAccessor headerAccessor =
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
@ -307,7 +341,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -307,7 +341,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
try {
SimpAttributesContextHolder.setAttributesFromMessage(message);
sent = outputChannel.send(message);
sent = channelToUse.send(message);
if (sent) {
if (this.eventPublisher != null) {
@ -652,6 +686,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -652,6 +686,7 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
outputChannel.send(message);
}
finally {
this.messageChannels.remove(session.getId());
this.stompAuthentications.remove(session.getId());
SimpAttributesContextHolder.resetAttributes();
simpAttributes.sessionCompleted();

105
spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java

@ -18,6 +18,10 @@ package org.springframework.web.socket.messaging; @@ -18,6 +18,10 @@ package org.springframework.web.socket.messaging;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
@ -28,25 +32,22 @@ import org.junit.jupiter.api.TestInfo; @@ -28,25 +32,22 @@ import org.junit.jupiter.api.TestInfo;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Scope;
import org.springframework.context.annotation.ScopedProxyMode;
import org.springframework.core.task.TaskExecutor;
import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.messaging.simp.annotation.SubscribeMapping;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.support.AbstractSubscribableChannel;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
import org.springframework.messaging.simp.stomp.StompDecoder;
import org.springframework.stereotype.Controller;
import org.springframework.web.socket.AbstractWebSocketIntegrationTests;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.WebSocketTestServer;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.config.annotation.DelegatingWebSocketMessageBrokerConfiguration;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.handler.TextWebSocketHandler;
@ -69,7 +70,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @@ -69,7 +70,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@Override
protected Class<?>[] getAnnotatedConfigClasses() {
return new Class<?>[] {TestMessageBrokerConfiguration.class, TestMessageBrokerConfigurer.class};
return new Class<?>[] {TestMessageBrokerConfigurer.class};
}
@ -100,7 +101,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @@ -100,7 +101,7 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
TextMessage m2 = create(StompCommand.SEND)
.headers("destination:/app/increment").body("5").build();
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2);
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2);
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
assertThat(session).isNotNull();
@ -121,17 +122,49 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @@ -121,17 +122,49 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
TextMessage m1 = create(StompCommand.SUBSCRIBE).headers("id:subs1", destination, selector).build();
TextMessage m2 = create(StompCommand.SEND).headers(destination, "foo:bar").body("5").build();
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2);
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2);
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
assertThat(session).isNotNull();
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
String payload = clientHandler.actual.get(1).getPayload();
String payload = clientHandler.actual.get(0).getPayload();
assertThat(payload).as("Expected STOMP MESSAGE, got " + payload).startsWith("MESSAGE\n");
}
}
@ParameterizedWebSocketTest // gh-21798
void sendMessageToBrokerAndReceiveInOrder(
WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
super.setup(server, webSocketClient, testInfo);
String destination = "destination:/topic/foo";
List<TextMessage> messages = new ArrayList<>();
messages.add(create(StompCommand.CONNECT).headers("accept-version:1.1").build());
messages.add(create(StompCommand.SUBSCRIBE).headers("id:subs1", destination).build());
int count = 1000;
for (int i = 0; i < count; i++) {
messages.add(create(StompCommand.SEND).headers(destination).body(String.valueOf(i)).build());
}
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(count, messages);
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
assertThat(session).isNotNull();
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
for (int i = 0; i < count; i++) {
TextMessage message = clientHandler.actual.get(i);
ByteBuffer buffer = ByteBuffer.wrap(message.asBytes());
byte[] bytes = new StompDecoder().decode(buffer).get(0).getPayload();
assertThat(new String(bytes, StandardCharsets.UTF_8)).isEqualTo(String.valueOf(i));
}
}
}
@ParameterizedWebSocketTest // SPR-11648
void sendSubscribeToControllerAndReceiveReply(
WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
@ -142,12 +175,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @@ -142,12 +175,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
String destHeader = "destination:/app/number";
TextMessage m1 = create(StompCommand.SUBSCRIBE).headers("id:subs1", destHeader).build();
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1);
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1);
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
assertThat(session).isNotNull();
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
String payload = clientHandler.actual.get(1).getPayload();
String payload = clientHandler.actual.get(0).getPayload();
assertThat(payload).as("Expected STOMP destination=/app/number, got " + payload).contains(destHeader);
assertThat(payload).as("Expected STOMP Payload=42, got " + payload).contains("42");
}
@ -164,12 +197,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @@ -164,12 +197,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
TextMessage m1 = create(StompCommand.SUBSCRIBE).headers("id:subs1", destHeader).build();
TextMessage m2 = create(StompCommand.SEND).headers("destination:/app/exception").build();
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2);
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2);
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
assertThat(session).isNotNull();
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
String payload = clientHandler.actual.get(1).getPayload();
String payload = clientHandler.actual.get(0).getPayload();
assertThat(payload).startsWith("MESSAGE\n");
assertThat(payload).contains("destination:/user/queue/error\n");
assertThat(payload).endsWith("Got error: Bad input\0");
@ -188,12 +221,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @@ -188,12 +221,12 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
TextMessage m2 = create(StompCommand.SEND)
.headers("destination:/app/scopedBeanValue").build();
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(2, m0, m1, m2);
TestClientWebSocketHandler clientHandler = new TestClientWebSocketHandler(1, m0, m1, m2);
try (WebSocketSession session = execute(clientHandler, "/ws").get()) {
assertThat(session).isNotNull();
assertThat(clientHandler.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();
String payload = clientHandler.actual.get(1).getPayload();
String payload = clientHandler.actual.get(0).getPayload();
assertThat(payload).startsWith("MESSAGE\n");
assertThat(payload).contains("destination:/topic/scopedBeanValue\n");
assertThat(payload).endsWith("55\0");
@ -285,18 +318,19 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @@ -285,18 +318,19 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
private static class TestClientWebSocketHandler extends TextWebSocketHandler {
private final TextMessage[] messagesToSend;
private final int expected;
private final List<TextMessage> messagesToSend;
private final List<TextMessage> actual = new CopyOnWriteArrayList<>();
private final CountDownLatch latch;
public TestClientWebSocketHandler(int expectedNumberOfMessages, TextMessage... messagesToSend) {
TestClientWebSocketHandler(int expectedNumberOfMessages, TextMessage... messagesToSend) {
this(expectedNumberOfMessages, Arrays.asList(messagesToSend));
}
TestClientWebSocketHandler(int expectedNumberOfMessages, List<TextMessage> messagesToSend) {
this.messagesToSend = messagesToSend;
this.expected = expectedNumberOfMessages;
this.latch = new CountDownLatch(this.expected);
this.latch = new CountDownLatch(expectedNumberOfMessages);
}
@Override
@ -307,18 +341,20 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @@ -307,18 +341,20 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
this.actual.add(message);
this.latch.countDown();
protected void handleTextMessage(WebSocketSession session, TextMessage message) {
if (!message.getPayload().startsWith("CONNECTED")) {
this.actual.add(message);
this.latch.countDown();
}
}
}
@Configuration
@ComponentScan(
basePackageClasses = StompWebSocketIntegrationTests.class,
useDefaultFilters = false,
includeFilters = @ComponentScan.Filter(IntegrationTestController.class))
@EnableWebSocketMessageBroker
static class TestMessageBrokerConfigurer implements WebSocketMessageBrokerConfigurer {
@Autowired
@ -326,12 +362,14 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @@ -326,12 +362,14 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.setPreserveReceiveOrder(true);
registry.addEndpoint("/ws").setHandshakeHandler(this.handshakeHandler);
}
@Override
public void configureMessageBroker(MessageBrokerRegistry configurer) {
configurer.setApplicationDestinationPrefixes("/app");
configurer.setPreservePublishOrder(true);
configurer.enableSimpleBroker("/topic", "/queue").setSelectorHeaderName("selector");
}
@ -342,21 +380,4 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests { @@ -342,21 +380,4 @@ class StompWebSocketIntegrationTests extends AbstractWebSocketIntegrationTests {
}
}
@Configuration
static class TestMessageBrokerConfiguration extends DelegatingWebSocketMessageBrokerConfiguration {
@Override
@Bean
public AbstractSubscribableChannel clientInboundChannel(TaskExecutor clientInboundChannelExecutor) {
return new ExecutorSubscribableChannel(); // synchronous
}
@Override
@Bean
public AbstractSubscribableChannel clientOutboundChannel(TaskExecutor clientOutboundChannelExecutor) {
return new ExecutorSubscribableChannel(); // synchronous
}
}
}

Loading…
Cancel
Save