Browse Source

Improve expired session check algorithm

1. Add session count threshold as am extra pre-condition.
2. Check pre-conditions for expiration checks on every request.

Effectively an upper bound on how many sessions can be created before
expiration checks are performed.

Issue: SPR-17020
pull/1783/merge
Rossen Stoyanchev 7 years ago
parent
commit
32b75221b3
  1. 103
      spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java
  2. 83
      spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java

103
spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java

@ -20,6 +20,7 @@ import java.time.Clock; @@ -20,6 +20,7 @@ import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@ -43,9 +44,6 @@ import org.springframework.web.server.WebSession; @@ -43,9 +44,6 @@ import org.springframework.web.server.WebSession;
*/
public class InMemoryWebSessionStore implements WebSessionStore {
/** Minimum period between expiration checks. */
private static final Duration EXPIRATION_CHECK_PERIOD = Duration.ofSeconds(60);
private static final IdGenerator idGenerator = new JdkIdGenerator();
@ -53,9 +51,7 @@ public class InMemoryWebSessionStore implements WebSessionStore { @@ -53,9 +51,7 @@ public class InMemoryWebSessionStore implements WebSessionStore {
private final ConcurrentMap<String, InMemoryWebSession> sessions = new ConcurrentHashMap<>();
private volatile Instant nextExpirationCheckTime = Instant.now(this.clock).plus(EXPIRATION_CHECK_PERIOD);
private final ReentrantLock expirationCheckLock = new ReentrantLock();
private final ExpiredSessionChecker expiredSessionChecker = new ExpiredSessionChecker();
/**
@ -70,8 +66,7 @@ public class InMemoryWebSessionStore implements WebSessionStore { @@ -70,8 +66,7 @@ public class InMemoryWebSessionStore implements WebSessionStore {
public void setClock(Clock clock) {
Assert.notNull(clock, "Clock is required");
this.clock = clock;
// Force a check when clock changes..
this.nextExpirationCheckTime = Instant.now(this.clock);
this.expiredSessionChecker.removeExpiredSessions(clock.instant());
}
/**
@ -84,49 +79,29 @@ public class InMemoryWebSessionStore implements WebSessionStore { @@ -84,49 +79,29 @@ public class InMemoryWebSessionStore implements WebSessionStore {
@Override
public Mono<WebSession> createWebSession() {
return Mono.fromSupplier(InMemoryWebSession::new);
Instant now = this.clock.instant();
this.expiredSessionChecker.checkIfNecessary(now);
return Mono.fromSupplier(() -> new InMemoryWebSession(now));
}
@Override
public Mono<WebSession> retrieveSession(String id) {
Instant currentTime = Instant.now(this.clock);
if (!this.sessions.isEmpty() && !currentTime.isBefore(this.nextExpirationCheckTime)) {
checkExpiredSessions(currentTime);
}
Instant now = this.clock.instant();
this.expiredSessionChecker.checkIfNecessary(now);
InMemoryWebSession session = this.sessions.get(id);
if (session == null) {
return Mono.empty();
}
else if (session.isExpired(currentTime)) {
else if (session.isExpired(now)) {
this.sessions.remove(id);
return Mono.empty();
}
else {
session.updateLastAccessTime(currentTime);
session.updateLastAccessTime(now);
return Mono.just(session);
}
}
private void checkExpiredSessions(Instant currentTime) {
if (this.expirationCheckLock.tryLock()) {
try {
Iterator<InMemoryWebSession> iterator = this.sessions.values().iterator();
while (iterator.hasNext()) {
InMemoryWebSession session = iterator.next();
if (session.isExpired(currentTime)) {
iterator.remove();
session.invalidate();
}
}
}
finally {
this.nextExpirationCheckTime = currentTime.plus(EXPIRATION_CHECK_PERIOD);
this.expirationCheckLock.unlock();
}
}
}
@Override
public Mono<Void> removeSession(String id) {
this.sessions.remove(id);
@ -137,7 +112,7 @@ public class InMemoryWebSessionStore implements WebSessionStore { @@ -137,7 +112,7 @@ public class InMemoryWebSessionStore implements WebSessionStore {
return Mono.fromSupplier(() -> {
Assert.isInstanceOf(InMemoryWebSession.class, webSession);
InMemoryWebSession session = (InMemoryWebSession) webSession;
session.updateLastAccessTime(Instant.now(getClock()));
session.updateLastAccessTime(getClock().instant());
return session;
});
}
@ -157,8 +132,9 @@ public class InMemoryWebSessionStore implements WebSessionStore { @@ -157,8 +132,9 @@ public class InMemoryWebSessionStore implements WebSessionStore {
private final AtomicReference<State> state = new AtomicReference<>(State.NEW);
public InMemoryWebSession() {
this.creationTime = Instant.now(getClock());
public InMemoryWebSession(Instant creationTime) {
this.creationTime = creationTime;
this.lastAccessTime = this.creationTime;
}
@ -256,6 +232,57 @@ public class InMemoryWebSessionStore implements WebSessionStore { @@ -256,6 +232,57 @@ public class InMemoryWebSessionStore implements WebSessionStore {
}
private class ExpiredSessionChecker {
/** Max time before next expiration checks. */
private static final int CHECK_PERIOD = 60;
/** Max sessions that can be created before next expiration checks. */
private static final int SESSION_COUNT_THRESHOLD = 500;
private final ReentrantLock lock = new ReentrantLock();
private Instant nextCheckTime = Instant.now(clock).plus(CHECK_PERIOD, ChronoUnit.SECONDS);
private long lastSessionCount;
public void checkIfNecessary(Instant now) {
if (howManyCreated() > SESSION_COUNT_THRESHOLD || this.nextCheckTime.isBefore(now)) {
removeExpiredSessions(Instant.now(clock));
}
}
private long howManyCreated() {
return sessions.size() - this.lastSessionCount;
}
public void removeExpiredSessions(Instant now) {
if (sessions.isEmpty()) {
return;
}
if (this.lock.tryLock()) {
try {
Iterator<InMemoryWebSession> iterator = sessions.values().iterator();
while (iterator.hasNext()) {
InMemoryWebSession session = iterator.next();
if (session.isExpired(now)) {
iterator.remove();
session.invalidate();
}
}
}
finally {
this.nextCheckTime = clock.instant().plus(CHECK_PERIOD, ChronoUnit.SECONDS);
this.lastSessionCount = sessions.size();
this.lock.unlock();
}
}
}
}
private enum State { NEW, STARTED, EXPIRED }
}

83
spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java

@ -18,15 +18,17 @@ package org.springframework.web.server.session; @@ -18,15 +18,17 @@ package org.springframework.web.server.session;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;
import org.junit.Test;
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.web.server.WebSession;
import static junit.framework.TestCase.assertSame;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
/**
* Unit tests for {@link InMemoryWebSessionStore}.
@ -91,44 +93,57 @@ public class InMemoryWebSessionStoreTests { @@ -91,44 +93,57 @@ public class InMemoryWebSessionStoreTests {
}
@Test
public void expirationChecks() {
// Create 3 sessions
WebSession session1 = this.store.createWebSession().block();
assertNotNull(session1);
session1.start();
session1.save().block();
public void expirationCheckBasedOnTimeWindow() {
WebSession session2 = this.store.createWebSession().block();
assertNotNull(session2);
session2.start();
session2.save().block();
DirectFieldAccessor accessor = new DirectFieldAccessor(this.store);
Map<?,?> sessions = (Map<?, ?>) accessor.getPropertyValue("sessions");
WebSession session3 = this.store.createWebSession().block();
assertNotNull(session3);
session3.start();
session3.save().block();
// Create 100 sessions
IntStream.range(0, 100).forEach(i -> insertSession());
// Fast-forward 31 minutes
this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofMinutes(31)));
// Force a new clock (31 min later) but don't use setter which would clean expired sessions
Clock newClock = Clock.offset(this.store.getClock(), Duration.ofMinutes(31));
accessor.setPropertyValue("clock", newClock);
assertEquals(100, sessions.size());
// Create 50 more which forces a time-based check (clock moved forward)
IntStream.range(0, 50).forEach(i -> insertSession());
assertEquals(50, sessions.size());
}
@Test
@SuppressWarnings("unchecked")
public void expirationCheckBasedOnSessionCount() {
DirectFieldAccessor accessor = new DirectFieldAccessor(this.store);
Map<String, WebSession> sessions = (Map<String, WebSession>) accessor.getPropertyValue("sessions");
// Create 100 sessions
IntStream.range(0, 100).forEach(i -> insertSession());
// Create 2 more sessions
WebSession session4 = this.store.createWebSession().block();
assertNotNull(session4);
session4.start();
session4.save().block();
// Copy sessions (about to be expired)
Map<String, WebSession> expiredSessions = new HashMap<>(sessions);
// Set new clock which expires and removes above sessions
this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofMinutes(31)));
assertEquals(0, sessions.size());
WebSession session5 = this.store.createWebSession().block();
assertNotNull(session5);
session5.start();
session5.save().block();
// Re-insert expired sessions
sessions.putAll(expiredSessions);
assertEquals(100, sessions.size());
// Retrieve, forcing cleanup of all expired..
assertNull(this.store.retrieveSession(session1.getId()).block());
assertNull(this.store.retrieveSession(session2.getId()).block());
assertNull(this.store.retrieveSession(session3.getId()).block());
// Create 600 more to go over the threshold
IntStream.range(0, 600).forEach(i -> insertSession());
assertEquals(600, sessions.size());
}
assertNotNull(this.store.retrieveSession(session4.getId()).block());
assertNotNull(this.store.retrieveSession(session5.getId()).block());
private WebSession insertSession() {
WebSession session = this.store.createWebSession().block();
assertNotNull(session);
session.start();
session.save().block();
return session;
}
}

Loading…
Cancel
Save