Browse Source

Earlier detection of token authentication

Use a callback to detect token authentication (via inteceptor) thus
avoiding a potential race between that detection after the message is
sent on the inbound channel (via Executor) and the processing of the
CONNECTED frame returned from the broker on the outbound channel.

Closes gh-23160
pull/23837/head
Rossen Stoyanchev 6 years ago
parent
commit
4e6e47b726
  1. 22
      spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java
  2. 36
      spring-messaging/src/test/java/org/springframework/messaging/simp/SimpMessageHeaderAccessorTests.java
  3. 18
      spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
  4. 9
      spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java

22
spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@ package org.springframework.messaging.simp; @@ -19,6 +19,7 @@ package org.springframework.messaging.simp;
import java.security.Principal;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
@ -84,6 +85,10 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { @@ -84,6 +85,10 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
public static final String IGNORE_ERROR = "simpIgnoreError";
@Nullable
private Consumer<Principal> userCallback;
/**
* A constructor for creating new message headers.
* This constructor is protected. See factory methods in this and sub-classes.
@ -171,6 +176,9 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { @@ -171,6 +176,9 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
public void setUser(@Nullable Principal principal) {
setHeader(USER_HEADER, principal);
if (this.userCallback != null) {
this.userCallback.accept(principal);
}
}
/**
@ -181,6 +189,18 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor { @@ -181,6 +189,18 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
return (Principal) getHeader(USER_HEADER);
}
/**
* Provide a callback to be invoked if and when {@link #setUser(Principal)}
* is called. This is used internally on the inbound channel to detect
* token-based authentications through an interceptor.
* @param callback the callback to invoke
* @since 5.1.9
*/
public void setUserChangeCallback(Consumer<Principal> callback) {
Assert.notNull(callback, "'callback' is required");
this.userCallback = this.userCallback != null ? this.userCallback.andThen(callback) : callback;
}
@Override
public String getShortLogMessage(Object payload) {
if (getMessageType() == null) {

36
spring-messaging/src/test/java/org/springframework/messaging/simp/SimpMessageHeaderAccessorTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,11 +16,14 @@ @@ -16,11 +16,14 @@
package org.springframework.messaging.simp;
import java.security.Principal;
import java.util.Collections;
import java.util.function.Consumer;
import org.junit.Test;
import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;
/**
* Unit tests for SimpMessageHeaderAccessor.
@ -63,4 +66,35 @@ public class SimpMessageHeaderAccessorTests { @@ -63,4 +66,35 @@ public class SimpMessageHeaderAccessorTests {
"{nativeKey=[nativeValue]} payload=p", accessor.getDetailedLogMessage("p"));
}
@Test
public void userChangeCallback() {
UserCallback userCallback = new UserCallback();
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create();
accessor.setUserChangeCallback(userCallback);
Principal user1 = mock(Principal.class);
accessor.setUser(user1);
assertEquals(user1, userCallback.getUser());
Principal user2 = mock(Principal.class);
accessor.setUser(user2);
assertEquals(user2, userCallback.getUser());
}
private static class UserCallback implements Consumer<Principal> {
private Principal user;
public Principal getUser() {
return this.user;
}
@Override
public void accept(Principal principal) {
this.user = principal;
}
}
}

18
spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java

@ -258,9 +258,19 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -258,9 +258,19 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
Assert.state(headerAccessor != null, "No StompHeaderAccessor");
StompCommand command = headerAccessor.getCommand();
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
headerAccessor.setSessionId(session.getId());
headerAccessor.setSessionAttributes(session.getAttributes());
headerAccessor.setUser(getUser(session));
if (isConnect) {
headerAccessor.setUserChangeCallback(user -> {
if (user != null && user != session.getPrincipal()) {
this.stompAuthentications.put(session.getId(), user);
}
});
}
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
if (!detectImmutableMessageInterceptor(outputChannel)) {
headerAccessor.setImmutable();
@ -270,8 +280,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -270,8 +280,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
}
StompCommand command = headerAccessor.getCommand();
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
if (isConnect) {
this.stats.incrementConnectCount();
}
@ -284,12 +292,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE @@ -284,12 +292,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
boolean sent = outputChannel.send(message);
if (sent) {
if (isConnect) {
Principal user = headerAccessor.getUser();
if (user != null && user != session.getPrincipal()) {
this.stompAuthentications.put(session.getId(), user);
}
}
if (this.eventPublisher != null) {
Principal user = getUser(session);
if (isConnect) {

9
spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java

@ -378,6 +378,15 @@ public class StompSubProtocolHandlerTests { @@ -378,6 +378,15 @@ public class StompSubProtocolHandlerTests {
Principal user = SimpMessageHeaderAccessor.getUser(message.getHeaders());
assertNotNull(user);
assertEquals("__pete__@gmail.com", user.getName());
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
handler.handleMessageToClient(this.session, message);
assertEquals(1, this.session.getSentMessages().size());
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
assertEquals("CONNECTED\n" + "user-name:__pete__@gmail.com\n" + "\n" + "\u0000",
textMessage.getPayload());
}
@Test

Loading…
Cancel
Save