diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java index a0cd808bce..248361ccd9 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java @@ -177,31 +177,13 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { } @Override - protected MultiValueMap findSubscriptionsInternal(String destination, - Message message) { - - LinkedMultiValueMap result = this.destinationCache.getSubscriptions(destination); - if (result != null) { - return filterSubscriptions(result, message); - } - result = new LinkedMultiValueMap(); - for (SessionSubscriptionInfo info : this.subscriptionRegistry.getAllSubscriptions()) { - for (String destinationPattern : info.getDestinations()) { - if (this.pathMatcher.match(destinationPattern, destination)) { - for (Subscription subscription : info.getSubscriptions(destinationPattern)) { - result.add(info.sessionId, subscription.getId()); - } - } - } - } - if (!result.isEmpty()) { - this.destinationCache.addSubscriptions(destination, result); - } + protected MultiValueMap findSubscriptionsInternal(String destination, Message message) { + MultiValueMap result = this.destinationCache.getSubscriptions(destination, message); return filterSubscriptions(result, message); } - private MultiValueMap filterSubscriptions(MultiValueMap allMatches, - Message message) { + private MultiValueMap filterSubscriptions( + MultiValueMap allMatches, Message message) { EvaluationContext context = null; MultiValueMap result = new LinkedMultiValueMap(allMatches.size()); @@ -264,20 +246,38 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { new LinkedHashMap>(DEFAULT_CACHE_LIMIT, 0.75f, true) { @Override protected boolean removeEldestEntry(Map.Entry> eldest) { - return size() > getCacheLimit(); + if (size() > getCacheLimit()) { + accessCache.remove(eldest.getKey()); + return true; + } + else { + return false; + } } }; - public LinkedMultiValueMap getSubscriptions(String destination) { - return this.accessCache.get(destination); - } - - public void addSubscriptions(String destination, LinkedMultiValueMap subscriptions) { - synchronized (this.updateCache) { - this.updateCache.put(destination, subscriptions.deepCopy()); - this.accessCache.put(destination, subscriptions); + public LinkedMultiValueMap getSubscriptions(String destination, Message message) { + LinkedMultiValueMap result = this.accessCache.get(destination); + if (result == null) { + synchronized (this.updateCache) { + result = new LinkedMultiValueMap(); + for (SessionSubscriptionInfo info : subscriptionRegistry.getAllSubscriptions()) { + for (String destinationPattern : info.getDestinations()) { + if (getPathMatcher().match(destinationPattern, destination)) { + for (Subscription subscription : info.getSubscriptions(destinationPattern)) { + result.add(info.sessionId, subscription.getId()); + } + } + } + } + if (!result.isEmpty()) { + this.updateCache.put(destination, result.deepCopy()); + this.accessCache.put(destination, result); + } + } } + return result; } public void updateAfterNewSubscription(String destination, String sessionId, String subsId) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java index ff810b701c..e065a68d5a 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java @@ -21,9 +21,6 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; import org.junit.Test; @@ -31,11 +28,10 @@ import org.springframework.messaging.Message; import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.messaging.simp.SimpMessageType; import org.springframework.messaging.support.MessageBuilder; -import org.springframework.util.AntPathMatcher; import org.springframework.util.MultiValueMap; -import org.springframework.util.PathMatcher; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; /** * Test fixture for @@ -402,37 +398,22 @@ public class DefaultSubscriptionRegistryTests { // no ConcurrentModificationException } - @Test // SPR-13204 - public void findSubscriptionsWithConcurrentUnregisterAllSubscriptions() throws Exception { - final CountDownLatch iterationPausedLatch = new CountDownLatch(1); - final CountDownLatch iterationResumeLatch = new CountDownLatch(1); - final CountDownLatch iterationDoneLatch = new CountDownLatch(1); - - PathMatcher pathMatcher = new PausingPathMatcher(iterationPausedLatch, iterationResumeLatch); - this.registry.setPathMatcher(pathMatcher); + @Test // SPR-13555 + public void cacheLimitExceeded() throws Exception { + this.registry.setCacheLimit(1); this.registry.registerSubscription(subscribeMessage("sess1", "1", "/foo")); - this.registry.registerSubscription(subscribeMessage("sess2", "1", "/foo")); - - AtomicReference> subscriptions = new AtomicReference<>(); - new Thread(() -> { - subscriptions.set(registry.findSubscriptions(createMessage("/foo"))); - iterationDoneLatch.countDown(); - }).start(); + this.registry.registerSubscription(subscribeMessage("sess1", "2", "/bar")); - assertTrue(iterationPausedLatch.await(10, TimeUnit.SECONDS)); + assertEquals(1, this.registry.findSubscriptions(createMessage("/foo")).size()); + assertEquals(1, this.registry.findSubscriptions(createMessage("/bar")).size()); - this.registry.unregisterAllSubscriptions("sess1"); - this.registry.unregisterAllSubscriptions("sess2"); - - iterationResumeLatch.countDown(); - assertTrue(iterationDoneLatch.await(10, TimeUnit.SECONDS)); + this.registry.registerSubscription(subscribeMessage("sess2", "1", "/foo")); + this.registry.registerSubscription(subscribeMessage("sess2", "2", "/bar")); - MultiValueMap result = subscriptions.get(); - assertNotNull(result); - assertEquals(0, result.size()); + assertEquals(2, this.registry.findSubscriptions(createMessage("/foo")).size()); + assertEquals(2, this.registry.findSubscriptions(createMessage("/bar")).size()); } - private Message createMessage(String destination) { SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(); accessor.setDestination(destination); @@ -468,35 +449,4 @@ public class DefaultSubscriptionRegistryTests { return list; } - - /** - * An extension of AntPathMatcher with a pair of CountDownLatches to pause - * while matching, allowing another thread to something, and resume when the - * other thread signals it's okay to do so. - */ - private static class PausingPathMatcher extends AntPathMatcher { - - private final CountDownLatch iterationPausedLatch; - - private final CountDownLatch iterationResumeLatch; - - public PausingPathMatcher(CountDownLatch iterationPausedLatch, CountDownLatch iterationResumeLatch) { - this.iterationPausedLatch = iterationPausedLatch; - this.iterationResumeLatch = iterationResumeLatch; - } - - @Override - public boolean match(String pattern, String path) { - try { - this.iterationPausedLatch.countDown(); - assertTrue(this.iterationResumeLatch.await(10, TimeUnit.SECONDS)); - return super.match(pattern, path); - } - catch (InterruptedException ex) { - ex.printStackTrace(); - return false; - } - } - } - }