diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java index a7193e8e86..9f0769af9c 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/user/DefaultUserDestinationResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -121,16 +121,17 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { @Override public UserDestinationResult resolveDestination(Message message) { - String sourceDestination = SimpMessageHeaderAccessor.getDestination(message.getHeaders()); ParseResult parseResult = parse(message); if (parseResult == null) { return null; } String user = parseResult.getUser(); + String sourceDestination = parseResult.getSourceDestination(); Set targetSet = new HashSet<>(); for (String sessionId : parseResult.getSessionIds()) { String actualDestination = parseResult.getActualDestination(); - String targetDestination = getTargetDestination(sourceDestination, actualDestination, sessionId, user); + String targetDestination = getTargetDestination( + sourceDestination, actualDestination, sessionId, user); if (targetDestination != null) { targetSet.add(targetDestination); } @@ -142,65 +143,84 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { @Nullable private ParseResult parse(Message message) { MessageHeaders headers = message.getHeaders(); - String destination = SimpMessageHeaderAccessor.getDestination(headers); - if (destination == null || !checkDestination(destination, this.prefix)) { + String sourceDestination = SimpMessageHeaderAccessor.getDestination(headers); + if (sourceDestination == null || !checkDestination(sourceDestination, this.prefix)) { return null; } SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers); + switch (messageType) { + case SUBSCRIBE: + case UNSUBSCRIBE: + return parseSubscriptionMessage(message, sourceDestination); + case MESSAGE: + return parseMessage(headers, sourceDestination); + default: + return null; + } + } + + private ParseResult parseSubscriptionMessage(Message message, String sourceDestination) { + MessageHeaders headers = message.getHeaders(); + String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); + if (sessionId == null) { + logger.error("No session id. Ignoring " + message); + return null; + } + int prefixEnd = this.prefix.length() - 1; + String actualDestination = sourceDestination.substring(prefixEnd); + if (!this.keepLeadingSlash) { + actualDestination = actualDestination.substring(1); + } Principal principal = SimpMessageHeaderAccessor.getUser(headers); + String user = (principal != null ? principal.getName() : null); + Set sessionIds = Collections.singleton(sessionId); + return new ParseResult(sourceDestination, actualDestination, sourceDestination, + sessionIds, user); + } + + private ParseResult parseMessage(MessageHeaders headers, String sourceDestination) { + int prefixEnd = this.prefix.length(); + int userEnd = sourceDestination.indexOf('/', prefixEnd); + Assert.isTrue(userEnd > 0, "Expected destination pattern \"/user/{userId}/**\""); + String actualDestination = sourceDestination.substring(userEnd); + String subscribeDestination = this.prefix.substring(0, prefixEnd - 1) + actualDestination; + String userName = sourceDestination.substring(prefixEnd, userEnd); + userName = StringUtils.replace(userName, "%2F", "/"); String sessionId = SimpMessageHeaderAccessor.getSessionId(headers); - if (SimpMessageType.SUBSCRIBE.equals(messageType) || SimpMessageType.UNSUBSCRIBE.equals(messageType)) { - if (sessionId == null) { - logger.error("No session id. Ignoring " + message); - return null; - } - int prefixEnd = this.prefix.length() - 1; - String actualDestination = destination.substring(prefixEnd); - if (!this.keepLeadingSlash) { - actualDestination = actualDestination.substring(1); - } - String user = (principal != null ? principal.getName() : null); - return new ParseResult(actualDestination, destination, Collections.singleton(sessionId), user); + Set sessionIds; + if (userName.equals(sessionId)) { + userName = null; + sessionIds = Collections.singleton(sessionId); + } + else { + sessionIds = getSessionIdsByUser(userName, sessionId); + } + if (!this.keepLeadingSlash) { + actualDestination = actualDestination.substring(1); } - else if (SimpMessageType.MESSAGE.equals(messageType)) { - int prefixEnd = this.prefix.length(); - int userEnd = destination.indexOf('/', prefixEnd); - Assert.isTrue(userEnd > 0, "Expected destination pattern \"/user/{userId}/**\""); - String actualDestination = destination.substring(userEnd); - String subscribeDestination = this.prefix.substring(0, prefixEnd - 1) + actualDestination; - String userName = destination.substring(prefixEnd, userEnd); - userName = StringUtils.replace(userName, "%2F", "/"); - Set sessionIds; - if (userName.equals(sessionId)) { - userName = null; + return new ParseResult(sourceDestination, actualDestination, subscribeDestination, + sessionIds, userName); + } + + private Set getSessionIdsByUser(String userName, String sessionId) { + Set sessionIds; + SimpUser user = this.userRegistry.getUser(userName); + if (user != null) { + if (user.getSession(sessionId) != null) { sessionIds = Collections.singleton(sessionId); } else { - SimpUser user = this.userRegistry.getUser(userName); - if (user != null) { - if (user.getSession(sessionId) != null) { - sessionIds = Collections.singleton(sessionId); - } - else { - Set sessions = user.getSessions(); - sessionIds = new HashSet<>(sessions.size()); - for (SimpSession session : sessions) { - sessionIds.add(session.getId()); - } - } - } - else { - sessionIds = Collections.emptySet(); + Set sessions = user.getSessions(); + sessionIds = new HashSet<>(sessions.size()); + for (SimpSession session : sessions) { + sessionIds.add(session.getId()); } } - if (!this.keepLeadingSlash) { - actualDestination = actualDestination.substring(1); - } - return new ParseResult(actualDestination, subscribeDestination, sessionIds, userName); } else { - return null; + sessionIds = Collections.emptySet(); } + return sessionIds; } protected boolean checkDestination(String destination, String requiredPrefix) { @@ -235,6 +255,8 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { */ private static class ParseResult { + private final String sourceDestination; + private final String actualDestination; private final String subscribeDestination; @@ -244,7 +266,10 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { private final String user; - public ParseResult(String actualDest, String subscribeDest, Set sessionIds, String user) { + public ParseResult(String sourceDest, String actualDest, String subscribeDest, + Set sessionIds, String user) { + + this.sourceDestination = sourceDest; this.actualDestination = actualDest; this.subscribeDestination = subscribeDest; this.sessionIds = sessionIds; @@ -252,6 +277,10 @@ public class DefaultUserDestinationResolver implements UserDestinationResolver { } + public String getSourceDestination() { + return this.sourceDestination; + } + public String getActualDestination() { return this.actualDestination; }