Browse Source

Enforce cacheLimit in DefaultSubscriptionRegistry

When the cacheLimit is reached and there is an eviction from the
updateCache, the accessCache is now also updated.

This change also ensures that adding a destination to the cache is
protected with synchronization on the updateCache.

Issue: SPR-13555
pull/890/head
Rossen Stoyanchev 9 years ago
parent
commit
7ff915a01a
  1. 62
      spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java
  2. 74
      spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java

62
spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java

@ -177,31 +177,13 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -177,31 +177,13 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
}
@Override
protected MultiValueMap<String, String> findSubscriptionsInternal(String destination,
Message<?> message) {
LinkedMultiValueMap<String, String> result = this.destinationCache.getSubscriptions(destination);
if (result != null) {
return filterSubscriptions(result, message);
}
result = new LinkedMultiValueMap<String, String>();
for (SessionSubscriptionInfo info : this.subscriptionRegistry.getAllSubscriptions()) {
for (String destinationPattern : info.getDestinations()) {
if (this.pathMatcher.match(destinationPattern, destination)) {
for (Subscription subscription : info.getSubscriptions(destinationPattern)) {
result.add(info.sessionId, subscription.getId());
}
}
}
}
if (!result.isEmpty()) {
this.destinationCache.addSubscriptions(destination, result);
}
protected MultiValueMap<String, String> findSubscriptionsInternal(String destination, Message<?> message) {
MultiValueMap<String, String> result = this.destinationCache.getSubscriptions(destination, message);
return filterSubscriptions(result, message);
}
private MultiValueMap<String, String> filterSubscriptions(MultiValueMap<String, String> allMatches,
Message<?> message) {
private MultiValueMap<String, String> filterSubscriptions(
MultiValueMap<String, String> allMatches, Message<?> message) {
EvaluationContext context = null;
MultiValueMap<String, String> result = new LinkedMultiValueMap<String, String>(allMatches.size());
@ -264,20 +246,38 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @@ -264,20 +246,38 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
new LinkedHashMap<String, LinkedMultiValueMap<String, String>>(DEFAULT_CACHE_LIMIT, 0.75f, true) {
@Override
protected boolean removeEldestEntry(Map.Entry<String, LinkedMultiValueMap<String, String>> eldest) {
return size() > getCacheLimit();
if (size() > getCacheLimit()) {
accessCache.remove(eldest.getKey());
return true;
}
else {
return false;
}
}
};
public LinkedMultiValueMap<String, String> getSubscriptions(String destination) {
return this.accessCache.get(destination);
}
public void addSubscriptions(String destination, LinkedMultiValueMap<String, String> subscriptions) {
synchronized (this.updateCache) {
this.updateCache.put(destination, subscriptions.deepCopy());
this.accessCache.put(destination, subscriptions);
public LinkedMultiValueMap<String, String> getSubscriptions(String destination, Message<?> message) {
LinkedMultiValueMap<String, String> result = this.accessCache.get(destination);
if (result == null) {
synchronized (this.updateCache) {
result = new LinkedMultiValueMap<String, String>();
for (SessionSubscriptionInfo info : subscriptionRegistry.getAllSubscriptions()) {
for (String destinationPattern : info.getDestinations()) {
if (getPathMatcher().match(destinationPattern, destination)) {
for (Subscription subscription : info.getSubscriptions(destinationPattern)) {
result.add(info.sessionId, subscription.getId());
}
}
}
}
if (!result.isEmpty()) {
this.updateCache.put(destination, result.deepCopy());
this.accessCache.put(destination, result);
}
}
}
return result;
}
public void updateAfterNewSubscription(String destination, String sessionId, String subsId) {

74
spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java

@ -21,9 +21,6 @@ import java.util.Collections; @@ -21,9 +21,6 @@ import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Test;
@ -31,11 +28,10 @@ import org.springframework.messaging.Message; @@ -31,11 +28,10 @@ import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.MultiValueMap;
import org.springframework.util.PathMatcher;
import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/**
* Test fixture for
@ -402,37 +398,22 @@ public class DefaultSubscriptionRegistryTests { @@ -402,37 +398,22 @@ public class DefaultSubscriptionRegistryTests {
// no ConcurrentModificationException
}
@Test // SPR-13204
public void findSubscriptionsWithConcurrentUnregisterAllSubscriptions() throws Exception {
final CountDownLatch iterationPausedLatch = new CountDownLatch(1);
final CountDownLatch iterationResumeLatch = new CountDownLatch(1);
final CountDownLatch iterationDoneLatch = new CountDownLatch(1);
PathMatcher pathMatcher = new PausingPathMatcher(iterationPausedLatch, iterationResumeLatch);
this.registry.setPathMatcher(pathMatcher);
@Test // SPR-13555
public void cacheLimitExceeded() throws Exception {
this.registry.setCacheLimit(1);
this.registry.registerSubscription(subscribeMessage("sess1", "1", "/foo"));
this.registry.registerSubscription(subscribeMessage("sess2", "1", "/foo"));
AtomicReference<MultiValueMap<String, String>> subscriptions = new AtomicReference<>();
new Thread(() -> {
subscriptions.set(registry.findSubscriptions(createMessage("/foo")));
iterationDoneLatch.countDown();
}).start();
this.registry.registerSubscription(subscribeMessage("sess1", "2", "/bar"));
assertTrue(iterationPausedLatch.await(10, TimeUnit.SECONDS));
assertEquals(1, this.registry.findSubscriptions(createMessage("/foo")).size());
assertEquals(1, this.registry.findSubscriptions(createMessage("/bar")).size());
this.registry.unregisterAllSubscriptions("sess1");
this.registry.unregisterAllSubscriptions("sess2");
iterationResumeLatch.countDown();
assertTrue(iterationDoneLatch.await(10, TimeUnit.SECONDS));
this.registry.registerSubscription(subscribeMessage("sess2", "1", "/foo"));
this.registry.registerSubscription(subscribeMessage("sess2", "2", "/bar"));
MultiValueMap<String, String> result = subscriptions.get();
assertNotNull(result);
assertEquals(0, result.size());
assertEquals(2, this.registry.findSubscriptions(createMessage("/foo")).size());
assertEquals(2, this.registry.findSubscriptions(createMessage("/bar")).size());
}
private Message<?> createMessage(String destination) {
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create();
accessor.setDestination(destination);
@ -468,35 +449,4 @@ public class DefaultSubscriptionRegistryTests { @@ -468,35 +449,4 @@ public class DefaultSubscriptionRegistryTests {
return list;
}
/**
* An extension of AntPathMatcher with a pair of CountDownLatches to pause
* while matching, allowing another thread to something, and resume when the
* other thread signals it's okay to do so.
*/
private static class PausingPathMatcher extends AntPathMatcher {
private final CountDownLatch iterationPausedLatch;
private final CountDownLatch iterationResumeLatch;
public PausingPathMatcher(CountDownLatch iterationPausedLatch, CountDownLatch iterationResumeLatch) {
this.iterationPausedLatch = iterationPausedLatch;
this.iterationResumeLatch = iterationResumeLatch;
}
@Override
public boolean match(String pattern, String path) {
try {
this.iterationPausedLatch.countDown();
assertTrue(this.iterationResumeLatch.await(10, TimeUnit.SECONDS));
return super.match(pattern, path);
}
catch (InterruptedException ex) {
ex.printStackTrace();
return false;
}
}
}
}

Loading…
Cancel
Save