From 4e6e47b726cacf983a6d630eeacf45d18be7e622 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 3 Jul 2019 15:22:56 +0100 Subject: [PATCH] Earlier detection of token authentication Use a callback to detect token authentication (via inteceptor) thus avoiding a potential race between that detection after the message is sent on the inbound channel (via Executor) and the processing of the CONNECTED frame returned from the broker on the outbound channel. Closes gh-23160 --- .../simp/SimpMessageHeaderAccessor.java | 22 +++++++++++- .../simp/SimpMessageHeaderAccessorTests.java | 36 ++++++++++++++++++- .../messaging/StompSubProtocolHandler.java | 18 +++++----- .../StompSubProtocolHandlerTests.java | 9 +++++ 4 files changed, 75 insertions(+), 10 deletions(-) diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java index b5d7d57e77..6c314f3329 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package org.springframework.messaging.simp; import java.security.Principal; import java.util.List; import java.util.Map; +import java.util.function.Consumer; import org.springframework.lang.Nullable; import org.springframework.messaging.Message; @@ -84,6 +85,10 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { public static final String IGNORE_ERROR = "simpIgnoreError"; + @Nullable + private Consumer userCallback; + + /** * A constructor for creating new message headers. * This constructor is protected. See factory methods in this and sub-classes. @@ -171,6 +176,9 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { public void setUser(@Nullable Principal principal) { setHeader(USER_HEADER, principal); + if (this.userCallback != null) { + this.userCallback.accept(principal); + } } /** @@ -181,6 +189,18 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { return (Principal) getHeader(USER_HEADER); } + /** + * Provide a callback to be invoked if and when {@link #setUser(Principal)} + * is called. This is used internally on the inbound channel to detect + * token-based authentications through an interceptor. + * @param callback the callback to invoke + * @since 5.1.9 + */ + public void setUserChangeCallback(Consumer callback) { + Assert.notNull(callback, "'callback' is required"); + this.userCallback = this.userCallback != null ? this.userCallback.andThen(callback) : callback; + } + @Override public String getShortLogMessage(Object payload) { if (getMessageType() == null) { diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpMessageHeaderAccessorTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpMessageHeaderAccessorTests.java index 2b8fc17858..ce78d02991 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpMessageHeaderAccessorTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/SimpMessageHeaderAccessorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,14 @@ package org.springframework.messaging.simp; +import java.security.Principal; import java.util.Collections; +import java.util.function.Consumer; import org.junit.Test; import static org.junit.Assert.*; +import static org.mockito.Mockito.mock; /** * Unit tests for SimpMessageHeaderAccessor. @@ -63,4 +66,35 @@ public class SimpMessageHeaderAccessorTests { "{nativeKey=[nativeValue]} payload=p", accessor.getDetailedLogMessage("p")); } + @Test + public void userChangeCallback() { + UserCallback userCallback = new UserCallback(); + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(); + accessor.setUserChangeCallback(userCallback); + + Principal user1 = mock(Principal.class); + accessor.setUser(user1); + assertEquals(user1, userCallback.getUser()); + + Principal user2 = mock(Principal.class); + accessor.setUser(user2); + assertEquals(user2, userCallback.getUser()); + } + + + private static class UserCallback implements Consumer { + + private Principal user; + + + public Principal getUser() { + return this.user; + } + + @Override + public void accept(Principal principal) { + this.user = principal; + } + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 0c1be83101..32e923a90e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -258,9 +258,19 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); Assert.state(headerAccessor != null, "No StompHeaderAccessor"); + StompCommand command = headerAccessor.getCommand(); + boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command); + headerAccessor.setSessionId(session.getId()); headerAccessor.setSessionAttributes(session.getAttributes()); headerAccessor.setUser(getUser(session)); + if (isConnect) { + headerAccessor.setUserChangeCallback(user -> { + if (user != null && user != session.getPrincipal()) { + this.stompAuthentications.put(session.getId(), user); + } + }); + } headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat()); if (!detectImmutableMessageInterceptor(outputChannel)) { headerAccessor.setImmutable(); @@ -270,8 +280,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload())); } - StompCommand command = headerAccessor.getCommand(); - boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command); if (isConnect) { this.stats.incrementConnectCount(); } @@ -284,12 +292,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE boolean sent = outputChannel.send(message); if (sent) { - if (isConnect) { - Principal user = headerAccessor.getUser(); - if (user != null && user != session.getPrincipal()) { - this.stompAuthentications.put(session.getId(), user); - } - } if (this.eventPublisher != null) { Principal user = getUser(session); if (isConnect) { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java index e93a8f1473..63b45da608 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java @@ -378,6 +378,15 @@ public class StompSubProtocolHandlerTests { Principal user = SimpMessageHeaderAccessor.getUser(message.getHeaders()); assertNotNull(user); assertEquals("__pete__@gmail.com", user.getName()); + + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED); + message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders()); + handler.handleMessageToClient(this.session, message); + + assertEquals(1, this.session.getSentMessages().size()); + WebSocketMessage textMessage = this.session.getSentMessages().get(0); + assertEquals("CONNECTED\n" + "user-name:__pete__@gmail.com\n" + "\n" + "\u0000", + textMessage.getPayload()); } @Test