diff --git a/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java b/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java index e950209070..1d1ce7f58a 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java @@ -24,8 +24,6 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.util.Assert; -import org.springframework.util.IdGenerator; -import org.springframework.util.JdkIdGenerator; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebSession; @@ -36,13 +34,11 @@ import org.springframework.web.server.WebSession; * {@link WebSessionStore} * * @author Rossen Stoyanchev + * @author Rob Winch * @since 5.0 */ public class DefaultWebSessionManager implements WebSessionManager { - private static final IdGenerator idGenerator = new JdkIdGenerator(); - - private WebSessionIdResolver sessionIdResolver = new CookieWebSessionIdResolver(); private WebSessionStore sessionStore = new InMemoryWebSessionStore(); @@ -111,22 +107,20 @@ public class DefaultWebSessionManager implements WebSessionManager { return Mono.defer(() -> retrieveSession(exchange) .flatMap(session -> removeSessionIfExpired(exchange, session)) - .map(session -> { - Instant lastAccessTime = Instant.now(getClock()); - return new DefaultWebSession(session, lastAccessTime, s -> saveSession(exchange, s)); - }) + .flatMap(this.getSessionStore()::updateLastAccessTime) .switchIfEmpty(createSession(exchange)) + .cast(DefaultWebSession.class) + .map(session -> new DefaultWebSession(session, session.getLastAccessTime(), s -> saveSession(exchange, s))) .doOnNext(session -> exchange.getResponse().beforeCommit(session::save))); } - private Mono retrieveSession(ServerWebExchange exchange) { + private Mono retrieveSession(ServerWebExchange exchange) { return Flux.fromIterable(getSessionIdResolver().resolveSessionIds(exchange)) .concatMap(this.sessionStore::retrieveSession) - .cast(DefaultWebSession.class) .next(); } - private Mono removeSessionIfExpired(ServerWebExchange exchange, DefaultWebSession session) { + private Mono removeSessionIfExpired(ServerWebExchange exchange, WebSession session) { if (session.isExpired()) { this.sessionIdResolver.expireSession(exchange); return this.sessionStore.removeSession(session.getId()).then(Mono.empty()); @@ -162,11 +156,7 @@ public class DefaultWebSessionManager implements WebSessionManager { return ids.isEmpty() || !session.getId().equals(ids.get(0)); } - private Mono createSession(ServerWebExchange exchange) { - return Mono.fromSupplier(() -> - new DefaultWebSession(idGenerator, getClock(), - (oldId, session) -> this.sessionStore.changeSessionId(oldId, session), - session -> saveSession(exchange, session))); + private Mono createSession(ServerWebExchange exchange) { + return this.sessionStore.createWebSession(); } - } diff --git a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java index dd193faf5e..bd2a7a97ba 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java @@ -15,9 +15,15 @@ */ package org.springframework.web.server.session; +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneId; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import org.springframework.util.Assert; +import org.springframework.util.IdGenerator; +import org.springframework.util.JdkIdGenerator; import reactor.core.publisher.Mono; import org.springframework.web.server.WebSession; @@ -26,18 +32,17 @@ import org.springframework.web.server.WebSession; * Simple Map-based storage for {@link WebSession} instances. * * @author Rossen Stoyanchev + * @author Rob Winch * @since 5.0 */ public class InMemoryWebSessionStore implements WebSessionStore { - private final Map sessions = new ConcurrentHashMap<>(); + private static final IdGenerator idGenerator = new JdkIdGenerator(); + private Clock clock = Clock.system(ZoneId.of("GMT")); + + private final Map sessions = new ConcurrentHashMap<>(); - @Override - public Mono storeSession(WebSession session) { - this.sessions.put(session.getId(), session); - return Mono.empty(); - } @Override public Mono retrieveSession(String id) { @@ -45,16 +50,55 @@ public class InMemoryWebSessionStore implements WebSessionStore { } @Override - public Mono changeSessionId(String oldId, WebSession session) { + public Mono removeSession(String id) { + this.sessions.remove(id); + return Mono.empty(); + } + + public Mono createWebSession() { + return Mono.fromSupplier(() -> + new DefaultWebSession(idGenerator, getClock(), + (oldId, session) -> this.changeSessionId(oldId, session), + this::storeSession)); + } + + public Mono updateLastAccessTime(WebSession webSession) { + return Mono.fromSupplier(() -> { + DefaultWebSession session = (DefaultWebSession) webSession; + Instant lastAccessTime = Instant.now(getClock()); + return new DefaultWebSession(session, lastAccessTime); + }); + } + + /** + * Configure the {@link Clock} to use to set lastAccessTime on every created + * session and to calculate if it is expired. + *

This may be useful to align to different timezone or to set the clock + * back in a test, e.g. {@code Clock.offset(clock, Duration.ofMinutes(-31))} + * in order to simulate session expiration. + *

By default this is {@code Clock.system(ZoneId.of("GMT"))}. + * @param clock the clock to use + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "'clock' is required."); + this.clock = clock; + } + + /** + * Return the configured clock for session lastAccessTime calculations. + */ + public Clock getClock() { + return this.clock; + } + + private Mono changeSessionId(String oldId, WebSession session) { this.sessions.remove(oldId); this.sessions.put(session.getId(), session); return Mono.empty(); } - @Override - public Mono removeSession(String id) { - this.sessions.remove(id); + private Mono storeSession(WebSession session) { + this.sessions.put(session.getId(), session); return Mono.empty(); } - } diff --git a/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java b/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java index b52cafe546..26c064153b 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java @@ -19,20 +19,22 @@ import reactor.core.publisher.Mono; import org.springframework.web.server.WebSession; +import java.time.Instant; + /** * Strategy for {@link WebSession} persistence. * * @author Rossen Stoyanchev + * @author Rob Winch * @since 5.0 */ public interface WebSessionStore { /** - * Store the given WebSession. - * @param session the session to store - * @return a completion notification (success or error) + * Creates the WebSession that can be stored by this WebSessionStore. + * @return the session */ - Mono storeSession(WebSession session); + Mono createWebSession(); /** * Return the WebSession for the given id. @@ -41,18 +43,6 @@ public interface WebSessionStore { */ Mono retrieveSession(String sessionId); - /** - * Update WebSession data storage to reflect a change in session id. - *

Note that the same can be achieved via a combination of - * {@link #removeSession} + {@link #storeSession}. The purpose of this method - * is to allow a more efficient replacement of the session id mapping - * without replacing and storing the session with all of its data. - * @param oldId the previous session id - * @param session the session reflecting the changed session id - * @return completion notification (success or error) - */ - Mono changeSessionId(String oldId, WebSession session); - /** * Remove the WebSession for the specified id. * @param sessionId the id of the session to remove @@ -60,4 +50,10 @@ public interface WebSessionStore { */ Mono removeSession(String sessionId); + /** + * Update the last accessed time to now. + * @param webSession the session to update + * @return the session with the updated last access time + */ + Mono updateLastAccessTime(WebSession webSession); } diff --git a/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java b/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java index 7d7501e32a..290c8dac53 100644 --- a/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java @@ -74,6 +74,9 @@ public class DefaultWebSessionManagerTests { @Before public void setUp() throws Exception { + when(this.store.createWebSession()).thenReturn(Mono.just(createDefaultWebSession())); + when(this.store.updateLastAccessTime(any())).thenAnswer( invocation -> Mono.just(invocation.getArgument(0))); + this.manager = new DefaultWebSessionManager(); this.manager.setSessionIdResolver(this.idResolver); this.manager.setSessionStore(this.store); @@ -106,6 +109,7 @@ public class DefaultWebSessionManagerTests { session.save().block(); String id = session.getId(); + verify(this.store).createWebSession(); verify(this.store).storeSession(any()); verify(this.idResolver).setSessionId(any(), eq(id)); } @@ -118,6 +122,7 @@ public class DefaultWebSessionManagerTests { session.getAttributes().put("foo", "bar"); session.save().block(); + verify(this.store).createWebSession(); verify(this.idResolver).setSessionId(any(), any()); verify(this.store).storeSession(any()); } diff --git a/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java index 705044447f..be20377025 100644 --- a/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java @@ -115,7 +115,7 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe assertNotNull(session); Instant lastAccessTime = Clock.offset(this.sessionManager.getClock(), Duration.ofMinutes(-31)).instant(); session = new DefaultWebSession(session, lastAccessTime); - store.storeSession(session); + session.save().block(); // Third request: expired session, new session created request = RequestEntity.get(createUri()).header("Cookie", "SESSION=" + id).build();