diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java index 96fb329e72..16b59ed57e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java @@ -61,6 +61,11 @@ public interface WebSocketSession { */ String getRemoteAddress(); + /** + * Return the negotiated sub-protocol or {@code null} if none was specified. + */ + String getAcceptedProtocol(); + /** * Return whether the connection is still open. */ diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/ConfigurableWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/ConfigurableWebSocketSession.java index 17e8a53da9..47cae03678 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/ConfigurableWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/ConfigurableWebSocketSession.java @@ -20,6 +20,7 @@ import java.net.URI; import java.security.Principal; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.server.DefaultHandshakeHandler; /** * A WebSocketSession with configurable properties. @@ -37,4 +38,12 @@ public interface ConfigurableWebSocketSession extends WebSocketSession { void setPrincipal(Principal principal); + /** + * Set the protocol accepted as part of the WebSocket handshake. This property can be + * used when the WebSocket handshake is performed through + * {@link DefaultHandshakeHandler} rather than the underlying WebSocket runtime, or + * when there is no WebSocket handshake (e.g. SockJS HTTP fallback options) + */ + void setAcceptedProtocol(String protocol); + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSessionAdapter.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSessionAdapter.java index c34ff2ba1a..525f3de6fa 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSessionAdapter.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSessionAdapter.java @@ -22,6 +22,7 @@ import java.net.URI; import java.security.Principal; import org.eclipse.jetty.websocket.api.Session; +import org.eclipse.jetty.websocket.api.UpgradeResponse; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; import org.springframework.web.socket.BinaryMessage; @@ -44,11 +45,20 @@ public class JettyWebSocketSessionAdapter private Principal principal; + private String protocol; + @Override public void initSession(Session session) { Assert.notNull(session, "session must not be null"); this.session = session; + + if (this.protocol == null) { + UpgradeResponse response = session.getUpgradeResponse(); + if ((response != null) && response.getAcceptedSubProtocol() != null) { + this.protocol = response.getAcceptedSubProtocol(); + } + } } @Override @@ -101,6 +111,16 @@ public class JettyWebSocketSessionAdapter // ignore } + @Override + public String getAcceptedProtocol() { + return this.protocol; + } + + @Override + public void setAcceptedProtocol(String protocol) { + this.protocol = protocol; + } + @Override public boolean isOpen() { return this.session.isOpen(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSessionAdapter.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSessionAdapter.java index 48c4d3c8f1..131105c195 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSessionAdapter.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSessionAdapter.java @@ -24,6 +24,7 @@ import javax.websocket.CloseReason; import javax.websocket.CloseReason.CloseCodes; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; import org.springframework.web.socket.BinaryMessage; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; @@ -45,11 +46,19 @@ public class StandardWebSocketSessionAdapter extends AbstractWebSocketSesssionAd private String remoteAddress; + private String protocol; + @Override public void initSession(javax.websocket.Session session) { Assert.notNull(session, "session must not be null"); this.session = session; + + if (this.protocol == null) { + if (StringUtils.hasText(session.getNegotiatedSubprotocol())) { + this.protocol = session.getNegotiatedSubprotocol(); + } + } } @Override @@ -103,6 +112,16 @@ public class StandardWebSocketSessionAdapter extends AbstractWebSocketSesssionAd this.remoteAddress = address; } + @Override + public String getAcceptedProtocol() { + return this.protocol; + } + + @Override + public void setAcceptedProtocol(String protocol) { + this.protocol = protocol; + } + @Override public boolean isOpen() { return this.session.isOpen(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java index ec2e7e7e8e..05475115e8 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java @@ -43,7 +43,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport { private WebSocketSession webSocketSession; - private final List subProtocols = new ArrayList(); + private final List protocols = new ArrayList(); private final boolean syncClientLifecycle; @@ -67,15 +67,15 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport { return new LoggingWebSocketHandlerDecorator(handler); } - public void setSubProtocols(List subProtocols) { - this.subProtocols.clear(); - if (!CollectionUtils.isEmpty(subProtocols)) { - this.subProtocols.addAll(subProtocols); + public void setSupportedProtocols(List protocols) { + this.protocols.clear(); + if (!CollectionUtils.isEmpty(protocols)) { + this.protocols.addAll(protocols); } } - public List getSubProtocols() { - return this.subProtocols; + public List getSupportedProtocols() { + return this.protocols; } @Override @@ -97,7 +97,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport { @Override protected void openConnection() throws Exception { HttpHeaders headers = new HttpHeaders(); - headers.setSecWebSocketProtocol(this.subProtocols); + headers.setSecWebSocketProtocol(this.protocols); this.webSocketSession = this.client.doHandshake(this.webSocketHandler, headers, getUri()); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/EndpointConnectionManager.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/EndpointConnectionManager.java index 25084c8c54..03fcdc97f8 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/EndpointConnectionManager.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/EndpointConnectionManager.java @@ -74,8 +74,8 @@ public class EndpointConnectionManager extends ConnectionManagerSupport implemen } - public void setSubProtocols(String... subprotocols) { - this.configBuilder.preferredSubprotocols(Arrays.asList(subprotocols)); + public void setSupportedProtocols(String... protocols) { + this.configBuilder.preferredSubprotocols(Arrays.asList(protocols)); } public void setExtensions(Extension... extensions) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java index 65e4e07f31..e57c0ec498 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java @@ -21,7 +21,6 @@ import java.nio.charset.Charset; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -35,7 +34,6 @@ import org.springframework.http.HttpStatus; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.ClassUtils; -import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.socket.WebSocketHandler; @@ -55,7 +53,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler { protected Log logger = LogFactory.getLog(getClass()); - private List supportedProtocols = new ArrayList(); + private final List supportedProtocols = new ArrayList(); private final RequestUpgradeStrategy requestUpgradeStrategy; @@ -78,11 +76,22 @@ public class DefaultHandshakeHandler implements HandshakeHandler { this.requestUpgradeStrategy = upgradeStrategy; } - + /** + * Use this property to configure a list of sub-protocols that are supported. + * The first protocol that matches what the client requested is selected. + * If no protocol matches or this property is not configured, then the + * response will not contain a Sec-WebSocket-Protocol header. + */ public void setSupportedProtocols(String... protocols) { - this.supportedProtocols = Arrays.asList(protocols); + this.supportedProtocols.clear(); + for (String protocol : protocols) { + this.supportedProtocols.add(protocol.toLowerCase()); + } } + /** + * Return the list of supported sub-protocols. + */ public String[] getSupportedProtocols() { return this.supportedProtocols.toArray(new String[this.supportedProtocols.size()]); } @@ -191,9 +200,12 @@ public class DefaultHandshakeHandler implements HandshakeHandler { } protected String selectProtocol(List requestedProtocols) { - if (CollectionUtils.isEmpty(requestedProtocols)) { + if (requestedProtocols != null) { for (String protocol : requestedProtocols) { - if (this.supportedProtocols.contains(protocol)) { + if (this.supportedProtocols.contains(protocol.toLowerCase())) { + if (logger.isDebugEnabled()) { + logger.debug("Selected sub-protocol '" + protocol + "'"); + } return protocol; } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/ServerEndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/ServerEndpointRegistration.java index 1f4075ffb2..53a9ad35f0 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/ServerEndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/ServerEndpointRegistration.java @@ -60,7 +60,7 @@ public class ServerEndpointRegistration implements ServerEndpointConfig, BeanFac private List> decoders = new ArrayList>(); - private List subprotocols = new ArrayList(); + private List protocols = new ArrayList(); private List extensions = new ArrayList(); @@ -113,13 +113,13 @@ public class ServerEndpointRegistration implements ServerEndpointConfig, BeanFac return (this.endpoint != null) ? this.endpoint : this.endpointProvider.getHandler(); } - public void setSubprotocols(List subprotocols) { - this.subprotocols = subprotocols; + public void setSubprotocols(List protocols) { + this.protocols = protocols; } @Override public List getSubprotocols() { - return this.subprotocols; + return this.protocols; } public void setExtensions(List extensions) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java index 7efc8a458b..14d1764359 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java @@ -48,7 +48,7 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS String protocol, WebSocketHandler handler) throws IOException, HandshakeFailureException { StandardWebSocketSessionAdapter session = new StandardWebSocketSessionAdapter(); - this.wsSessionInitializer.initialize(request, response, session); + this.wsSessionInitializer.initialize(request, response, protocol, session); StandardEndpointAdapter endpoint = new StandardEndpointAdapter(handler, session); upgradeInternal(request, response, protocol, endpoint); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java index 63cd67a60a..8690fe6f33 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java @@ -87,7 +87,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, - String selectedProtocol, WebSocketHandler webSocketHandler) throws IOException { + String protocol, WebSocketHandler webSocketHandler) throws IOException { Assert.isInstanceOf(ServletServerHttpRequest.class, request); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); @@ -101,7 +101,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { } JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter(); - this.wsSessionInitializer.initialize(request, response, session); + this.wsSessionInitializer.initialize(request, response, protocol, session); JettyWebSocketListenerAdapter listener = new JettyWebSocketListenerAdapter(webSocketHandler, session); servletRequest.setAttribute(WEBSOCKET_LISTENER_ATTR_NAME, listener); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/ServerWebSocketSessionInitializer.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/ServerWebSocketSessionInitializer.java index 7be6c1309a..50fef1e744 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/ServerWebSocketSessionInitializer.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/ServerWebSocketSessionInitializer.java @@ -30,11 +30,14 @@ import org.springframework.web.socket.adapter.ConfigurableWebSocketSession; */ public class ServerWebSocketSessionInitializer { - public void initialize(ServerHttpRequest request, ServerHttpResponse response, ConfigurableWebSocketSession session) { + public void initialize(ServerHttpRequest request, ServerHttpResponse response, + String protocol, ConfigurableWebSocketSession session) { + session.setUri(request.getURI()); session.setRemoteHostName(request.getRemoteHostName()); session.setRemoteAddress(request.getRemoteAddress()); session.setPrincipal(request.getPrincipal()); + session.setAcceptedProtocol(protocol); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/DefaultSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/DefaultSockJsService.java index 5e8c42097f..57a3d34415 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/DefaultSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/DefaultSockJsService.java @@ -245,7 +245,8 @@ public class DefaultSockJsService extends AbstractSockJsService { } logger.debug("Creating new session with session id \"" + sessionId + "\""); session = sessionFactory.createSession(sessionId, handler); - this.sessionInitializer.initialize(request, response, session); + String protocol = null; // TODO: https://github.com/sockjs/sockjs-client/issues/130 + this.sessionInitializer.initialize(request, response, protocol, session); this.sessions.put(sessionId, session); return session; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/AbstractHttpReceivingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/AbstractHttpReceivingTransportHandler.java index ba92c58fa8..ce98eb55f2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/AbstractHttpReceivingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/AbstractHttpReceivingTransportHandler.java @@ -63,6 +63,9 @@ public abstract class AbstractHttpReceivingTransportHandler implements Transport return; } + // TODO: check "Sec-WebSocket-Protocol" header + // https://github.com/sockjs/sockjs-client/issues/130 + handleRequestInternal(request, response, session); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/AbstractHttpSockJsSession.java index 43827d415c..e5e9482713 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/AbstractHttpSockJsSession.java @@ -49,12 +49,32 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { private ServerHttpResponse response; + private String protocol; + public AbstractHttpSockJsSession(String sessionId, SockJsConfiguration config, WebSocketHandler handler) { super(sessionId, config, handler); } + /** + * Unlike WebSocket where sub-protocol negotiation is part of the + * initial handshake, in HTTP transports the same negotiation must + * be emulated and the selected protocol set through this setter. + * + * @param protocol the sub-protocol to set + */ + public void setAcceptedProtocol(String protocol) { + this.protocol = protocol; + } + + /** + * Return the selected sub-protocol to use. + */ + public String getAcceptedProtocol() { + return this.protocol; + } + public synchronized void setInitialRequest(ServerHttpRequest request, ServerHttpResponse response, FrameFormat frameFormat) throws TransportErrorException { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/WebSocketServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/WebSocketServerSockJsSession.java index d8b451492c..639bfbeba0 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/WebSocketServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/WebSocketServerSockJsSession.java @@ -48,6 +48,20 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession { } + @Override + public String getAcceptedProtocol() { + if (this.webSocketSession == null) { + logger.warn("getAcceptedProtocol() invoked before WebSocketSession has been initialized."); + return null; + } + return this.webSocketSession.getAcceptedProtocol(); + } + + @Override + public void setAcceptedProtocol(String protocol) { + // ignore, webSocketSession should have it + } + public void initWebSocketSession(WebSocketSession session) throws Exception { this.webSocketSession = session; try { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandler.java new file mode 100644 index 0000000000..c579e34f51 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandler.java @@ -0,0 +1,125 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.support; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.springframework.util.Assert; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketMessage; +import org.springframework.web.socket.WebSocketSession; + + +/** + * A {@link WebSocketHandler} that delegates to other {@link WebSocketHandler} instances + * based on the sub-protocol value accepted at the handshake. A default handler can also + * be configured for use by default when a sub-protocol value if the WebSocket session + * does not have a sub-protocol value associated with it. + * + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class MultiProtocolWebSocketHandler implements WebSocketHandler { + + private WebSocketHandler defaultHandler; + + private Map handlers = new HashMap(); + + + /** + * Configure {@link WebSocketHandler}'s to use by sub-protocol. The values for + * sub-protocols are case insensitive. + */ + public void setProtocolHandlers(Map protocolHandlers) { + this.handlers.clear(); + for (String protocol : protocolHandlers.keySet()) { + this.handlers.put(protocol.toLowerCase(), protocolHandlers.get(protocol)); + } + } + + /** + * Return a read-only copy of the sub-protocol handler map. + */ + public Map getProtocolHandlers() { + return Collections.unmodifiableMap(this.handlers); + } + + /** + * Set the default {@link WebSocketHandler} to use if a sub-protocol was not + * requested. + */ + public void setDefaultProtocolHandler(WebSocketHandler defaultHandler) { + this.defaultHandler = defaultHandler; + } + + /** + * Return the default {@link WebSocketHandler} to be used. + */ + public WebSocketHandler getDefaultProtocolHandler() { + return this.defaultHandler; + } + + + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + WebSocketHandler handler = getHandlerForSession(session); + handler.afterConnectionEstablished(session); + } + + private WebSocketHandler getHandlerForSession(WebSocketSession session) { + WebSocketHandler handler = null; + String protocol = session.getAcceptedProtocol(); + if (protocol != null) { + handler = this.handlers.get(protocol.toLowerCase()); + Assert.state(handler != null, + "No WebSocketHandler for sub-protocol '" + protocol + "', handlers=" + this.handlers); + } + else { + handler = this.defaultHandler; + Assert.state(handler != null, + "No sub-protocol was requested and no default WebSocketHandler was configured"); + } + return handler; + } + + @Override + public void handleMessage(WebSocketSession session, WebSocketMessage message) throws Exception { + WebSocketHandler handler = getHandlerForSession(session); + handler.handleMessage(session, message); + } + + @Override + public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { + WebSocketHandler handler = getHandlerForSession(session); + handler.handleTransportError(session, exception); + } + + @Override + public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { + WebSocketHandler handler = getHandlerForSession(session); + handler.afterConnectionClosed(session, closeStatus); + } + + @Override + public boolean supportsPartialMessages() { + return false; + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java index 8e08674a2b..77c7a66147 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java @@ -51,7 +51,7 @@ public class WebSocketConnectionManagerTests { WebSocketHandler handler = new WebSocketHandlerAdapter(); WebSocketConnectionManager manager = new WebSocketConnectionManager(client, handler , "/path/{id}", "123"); - manager.setSubProtocols(subprotocols); + manager.setSupportedProtocols(subprotocols); manager.openConnection(); ArgumentCaptor captor = ArgumentCaptor.forClass(WebSocketHandlerDecorator.class); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java new file mode 100644 index 0000000000..c47ed6906d --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.springframework.web.socket.AbstractHttpRequestTests; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter; + +import static org.mockito.Mockito.*; + + +/** + * Test fixture for {@link DefaultHandshakeHandler}. + * + * @author Rossen Stoyanchev + */ +public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { + + private DefaultHandshakeHandler handshakeHandler; + + @Mock + private RequestUpgradeStrategy upgradeStrategy; + + + @Before + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + this.handshakeHandler = new DefaultHandshakeHandler(this.upgradeStrategy); + } + + + @Test + public void selectSubProtocol() throws Exception { + + this.handshakeHandler.setSupportedProtocols("stomp", "mqtt"); + + when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] { "13" }); + + this.servletRequest.setMethod("GET"); + this.request.getHeaders().setUpgrade("WebSocket"); + this.request.getHeaders().setConnection("Upgrade"); + this.request.getHeaders().setSecWebSocketVersion("13"); + this.request.getHeaders().setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw=="); + this.request.getHeaders().setSecWebSocketProtocol("STOMP"); + + WebSocketHandler handler = new TextWebSocketHandlerAdapter(); + + this.handshakeHandler.doHandshake(this.request, this.response, handler); + + verify(this.upgradeStrategy).upgrade(request, response, "STOMP", handler); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/TestSockJsSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/TestSockJsSession.java index 684db0d116..c861f8e87a 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/TestSockJsSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/TestSockJsSession.java @@ -40,11 +40,23 @@ public class TestSockJsSession extends AbstractSockJsSession { private boolean cancelledHeartbeat; + private String subProtocol; + public TestSockJsSession(String sessionId, SockJsConfiguration config, WebSocketHandler handler) { super(sessionId, config, handler); } + @Override + public String getAcceptedProtocol() { + return this.subProtocol; + } + + @Override + public void setAcceptedProtocol(String protocol) { + this.subProtocol = protocol; + } + public CloseStatus getStatus() { return this.status; } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandlerTests.java new file mode 100644 index 0000000000..533fce706d --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandlerTests.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIOsNS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.support; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.springframework.web.socket.WebSocketHandler; + +import static org.mockito.Mockito.*; + + +/** + * Test fixture for {@link MultiProtocolWebSocketHandler}. + * + * @author Rossen Stoyanchev + */ +public class MultiProtocolWebSocketHandlerTests { + + private MultiProtocolWebSocketHandler multiProtocolHandler; + + @Mock + WebSocketHandler stompHandler; + + @Mock + WebSocketHandler mqttHandler; + + @Mock + WebSocketHandler defaultHandler; + + + @Before + public void setup() { + + MockitoAnnotations.initMocks(this); + + Map handlers = new HashMap(); + handlers.put("STOMP", this.stompHandler); + handlers.put("MQTT", this.mqttHandler); + + this.multiProtocolHandler = new MultiProtocolWebSocketHandler(); + this.multiProtocolHandler.setProtocolHandlers(handlers); + this.multiProtocolHandler.setDefaultProtocolHandler(this.defaultHandler); + } + + + @Test + public void subProtocol() throws Exception { + + TestWebSocketSession session = new TestWebSocketSession(); + session.setAcceptedProtocol("sToMp"); + + this.multiProtocolHandler.afterConnectionEstablished(session); + + verify(this.stompHandler).afterConnectionEstablished(session); + verifyZeroInteractions(this.mqttHandler); + } + + @Test(expected=IllegalStateException.class) + public void subProtocolNoMatch() throws Exception { + + TestWebSocketSession session = new TestWebSocketSession(); + session.setAcceptedProtocol("wamp"); + + this.multiProtocolHandler.afterConnectionEstablished(session); + } + + @Test + public void noSubProtocol() throws Exception { + + TestWebSocketSession session = new TestWebSocketSession(); + + this.multiProtocolHandler.afterConnectionEstablished(session); + + verify(this.defaultHandler).afterConnectionEstablished(session); + verifyZeroInteractions(this.stompHandler, this.mqttHandler); + } + + @Test(expected=IllegalStateException.class) + public void noSubProtocolNoDefaultHandler() throws Exception { + + TestWebSocketSession session = new TestWebSocketSession(); + + this.multiProtocolHandler.setDefaultProtocolHandler(null); + this.multiProtocolHandler.afterConnectionEstablished(session); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java index a567c28573..5b6482a781 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java @@ -45,6 +45,8 @@ public class TestWebSocketSession implements WebSocketSession { private String remoteAddress; + private String protocol; + private boolean open; private final List> messages = new ArrayList<>(); @@ -142,6 +144,20 @@ public class TestWebSocketSession implements WebSocketSession { this.remoteAddress = remoteAddress; } + /** + * @return the subProtocol + */ + public String getAcceptedProtocol() { + return this.protocol; + } + + /** + * @param protocol the subProtocol to set + */ + public void setAcceptedProtocol(String protocol) { + this.protocol = protocol; + } + /** * @return the open */