From 4195e6906caf7cd743c6deb07892ea9589ccc961 Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Thu, 5 Oct 2023 16:42:20 +0100 Subject: [PATCH] OrderedMessageChannelDecorator supports multiple subscribers See gh-21798 --- .../OrderedMessageChannelDecorator.java | 18 ++++- .../OrderedMessageChannelDecoratorTests.java | 73 +++++++++++++------ 2 files changed, 63 insertions(+), 28 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java index 98ea4e422c..da3e7f4185 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecorator.java @@ -19,6 +19,7 @@ package org.springframework.messaging.simp.broker; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.logging.Log; @@ -48,6 +49,8 @@ public class OrderedMessageChannelDecorator implements MessageChannel { private final MessageChannel channel; + private final int subscriberCount; + private final Log logger; private final Queue> messages = new ConcurrentLinkedQueue<>(); @@ -57,6 +60,7 @@ public class OrderedMessageChannelDecorator implements MessageChannel { public OrderedMessageChannelDecorator(MessageChannel channel, Log logger) { this.channel = channel; + this.subscriberCount = (channel instanceof ExecutorSubscribableChannel ch ? ch.getSubscribers().size() : 0); this.logger = logger; } @@ -162,24 +166,30 @@ public class OrderedMessageChannelDecorator implements MessageChannel { /** * Remove handled message from queue, and send next message. */ - private class PostHandleTask implements Runnable { + private final class PostHandleTask implements Runnable { private final Message message; + @Nullable + private final AtomicInteger handledCount; + private PostHandleTask(Message message) { this.message = message; + this.handledCount = (subscriberCount > 1 ? new AtomicInteger(0) : null); } @Override public void run() { - if (OrderedMessageChannelDecorator.this.removeMessage(message)) { - sendNextMessage(); + if (this.handledCount == null || this.handledCount.addAndGet(1) == subscriberCount) { + if (OrderedMessageChannelDecorator.this.removeMessage(message)) { + sendNextMessage(); + } } } } - private static class CallbackTaskInterceptor implements ExecutorChannelInterceptor { + private final static class CallbackTaskInterceptor implements ExecutorChannelInterceptor { @Override public void afterMessageHandled( diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecoratorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecoratorTests.java index 4d8b8b636b..38449318ac 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecoratorTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/OrderedMessageChannelDecoratorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,9 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.MessagingException; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.ExecutorSubscribableChannel; @@ -45,10 +48,6 @@ public class OrderedMessageChannelDecoratorTests { private static final Log logger = LogFactory.getLog(OrderedMessageChannelDecoratorTests.class); - private OrderedMessageChannelDecorator sender; - - ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(this.executor); - private ThreadPoolTaskExecutor executor; @@ -58,12 +57,6 @@ public class OrderedMessageChannelDecoratorTests { this.executor.setCorePoolSize(Runtime.getRuntime().availableProcessors() * 2); this.executor.setAllowCoreThreadTimeOut(true); this.executor.afterPropertiesSet(); - - this.channel = new ExecutorSubscribableChannel(this.executor); - OrderedMessageChannelDecorator.configureInterceptor(this.channel, true); - - this.sender = new OrderedMessageChannelDecorator(this.channel, logger); - } @AfterEach @@ -78,11 +71,48 @@ public class OrderedMessageChannelDecoratorTests { int start = 1; int end = 1000; - AtomicInteger index = new AtomicInteger(start); - AtomicReference result = new AtomicReference<>(); - CountDownLatch latch = new CountDownLatch(1); + ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel(this.executor); + OrderedMessageChannelDecorator.configureInterceptor(channel, true); + + TestHandler handler1 = new TestHandler(start, end); + TestHandler handler2 = new TestHandler(start, end); + TestHandler handler3 = new TestHandler(start, end); + + channel.subscribe(handler1); + channel.subscribe(handler2); + channel.subscribe(handler3); + + OrderedMessageChannelDecorator sender = new OrderedMessageChannelDecorator(channel, logger); + for (int i = start; i <= end; i++) { + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + accessor.setHeader("seq", i); + accessor.setLeaveMutable(true); + sender.send(MessageBuilder.createMessage("payload", accessor.getMessageHeaders())); + } + + handler1.verify(); + handler2.verify(); + handler3.verify(); + } - this.channel.subscribe(message -> { + + private static class TestHandler implements MessageHandler { + + private final AtomicInteger index; + + private final int end; + + private final AtomicReference result = new AtomicReference<>(); + + private final CountDownLatch latch = new CountDownLatch(1); + + TestHandler(int start, int end) { + this.index = new AtomicInteger(start); + this.end = end; + } + + @Override + public void handleMessage(Message message) throws MessagingException { int expected = index.getAndIncrement(); Integer actual = (Integer) message.getHeaders().getOrDefault("seq", -1); if (actual != expected) { @@ -104,17 +134,12 @@ public class OrderedMessageChannelDecoratorTests { result.set("Done"); latch.countDown(); } - }); - - for (int i = start; i <= end; i++) { - SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); - accessor.setHeader("seq", i); - accessor.setLeaveMutable(true); - this.sender.send(MessageBuilder.createMessage("payload", accessor.getMessageHeaders())); } - latch.await(10, TimeUnit.SECONDS); - assertThat(result.get()).isEqualTo("Done"); + void verify() throws InterruptedException { + assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue(); + assertThat(result.get()).isEqualTo("Done"); + } } }