Browse Source

Add MultiProtocolWebSocketHandler

It makes it possible to deploy multiple WebSocketHandler's to a URL,
each supporting a different sub-protocol.

Issue: SPR-10786
pull/325/head
Rossen Stoyanchev 11 years ago
parent
commit
5a0e42b76e
  1. 5
      spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java
  2. 9
      spring-websocket/src/main/java/org/springframework/web/socket/adapter/ConfigurableWebSocketSession.java
  3. 20
      spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSessionAdapter.java
  4. 19
      spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSessionAdapter.java
  5. 16
      spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java
  6. 4
      spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/EndpointConnectionManager.java
  7. 26
      spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java
  8. 8
      spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/ServerEndpointRegistration.java
  9. 2
      spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java
  10. 4
      spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java
  11. 5
      spring-websocket/src/main/java/org/springframework/web/socket/server/support/ServerWebSocketSessionInitializer.java
  12. 3
      spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/DefaultSockJsService.java
  13. 3
      spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/AbstractHttpReceivingTransportHandler.java
  14. 20
      spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/AbstractHttpSockJsSession.java
  15. 14
      spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/WebSocketServerSockJsSession.java
  16. 125
      spring-websocket/src/main/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandler.java
  17. 2
      spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java
  18. 71
      spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java
  19. 12
      spring-websocket/src/test/java/org/springframework/web/socket/sockjs/TestSockJsSession.java
  20. 106
      spring-websocket/src/test/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandlerTests.java
  21. 16
      spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java

5
spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java

@ -61,6 +61,11 @@ public interface WebSocketSession { @@ -61,6 +61,11 @@ public interface WebSocketSession {
*/
String getRemoteAddress();
/**
* Return the negotiated sub-protocol or {@code null} if none was specified.
*/
String getAcceptedProtocol();
/**
* Return whether the connection is still open.
*/

9
spring-websocket/src/main/java/org/springframework/web/socket/adapter/ConfigurableWebSocketSession.java

@ -20,6 +20,7 @@ import java.net.URI; @@ -20,6 +20,7 @@ import java.net.URI;
import java.security.Principal;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.server.DefaultHandshakeHandler;
/**
* A WebSocketSession with configurable properties.
@ -37,4 +38,12 @@ public interface ConfigurableWebSocketSession extends WebSocketSession { @@ -37,4 +38,12 @@ public interface ConfigurableWebSocketSession extends WebSocketSession {
void setPrincipal(Principal principal);
/**
* Set the protocol accepted as part of the WebSocket handshake. This property can be
* used when the WebSocket handshake is performed through
* {@link DefaultHandshakeHandler} rather than the underlying WebSocket runtime, or
* when there is no WebSocket handshake (e.g. SockJS HTTP fallback options)
*/
void setAcceptedProtocol(String protocol);
}

20
spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSessionAdapter.java

@ -22,6 +22,7 @@ import java.net.URI; @@ -22,6 +22,7 @@ import java.net.URI;
import java.security.Principal;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.web.socket.BinaryMessage;
@ -44,11 +45,20 @@ public class JettyWebSocketSessionAdapter @@ -44,11 +45,20 @@ public class JettyWebSocketSessionAdapter
private Principal principal;
private String protocol;
@Override
public void initSession(Session session) {
Assert.notNull(session, "session must not be null");
this.session = session;
if (this.protocol == null) {
UpgradeResponse response = session.getUpgradeResponse();
if ((response != null) && response.getAcceptedSubProtocol() != null) {
this.protocol = response.getAcceptedSubProtocol();
}
}
}
@Override
@ -101,6 +111,16 @@ public class JettyWebSocketSessionAdapter @@ -101,6 +111,16 @@ public class JettyWebSocketSessionAdapter
// ignore
}
@Override
public String getAcceptedProtocol() {
return this.protocol;
}
@Override
public void setAcceptedProtocol(String protocol) {
this.protocol = protocol;
}
@Override
public boolean isOpen() {
return this.session.isOpen();

19
spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSessionAdapter.java

@ -24,6 +24,7 @@ import javax.websocket.CloseReason; @@ -24,6 +24,7 @@ import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCodes;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
@ -45,11 +46,19 @@ public class StandardWebSocketSessionAdapter extends AbstractWebSocketSesssionAd @@ -45,11 +46,19 @@ public class StandardWebSocketSessionAdapter extends AbstractWebSocketSesssionAd
private String remoteAddress;
private String protocol;
@Override
public void initSession(javax.websocket.Session session) {
Assert.notNull(session, "session must not be null");
this.session = session;
if (this.protocol == null) {
if (StringUtils.hasText(session.getNegotiatedSubprotocol())) {
this.protocol = session.getNegotiatedSubprotocol();
}
}
}
@Override
@ -103,6 +112,16 @@ public class StandardWebSocketSessionAdapter extends AbstractWebSocketSesssionAd @@ -103,6 +112,16 @@ public class StandardWebSocketSessionAdapter extends AbstractWebSocketSesssionAd
this.remoteAddress = address;
}
@Override
public String getAcceptedProtocol() {
return this.protocol;
}
@Override
public void setAcceptedProtocol(String protocol) {
this.protocol = protocol;
}
@Override
public boolean isOpen() {
return this.session.isOpen();

16
spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java

@ -43,7 +43,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport { @@ -43,7 +43,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
private WebSocketSession webSocketSession;
private final List<String> subProtocols = new ArrayList<String>();
private final List<String> protocols = new ArrayList<String>();
private final boolean syncClientLifecycle;
@ -67,15 +67,15 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport { @@ -67,15 +67,15 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
return new LoggingWebSocketHandlerDecorator(handler);
}
public void setSubProtocols(List<String> subProtocols) {
this.subProtocols.clear();
if (!CollectionUtils.isEmpty(subProtocols)) {
this.subProtocols.addAll(subProtocols);
public void setSupportedProtocols(List<String> protocols) {
this.protocols.clear();
if (!CollectionUtils.isEmpty(protocols)) {
this.protocols.addAll(protocols);
}
}
public List<String> getSubProtocols() {
return this.subProtocols;
public List<String> getSupportedProtocols() {
return this.protocols;
}
@Override
@ -97,7 +97,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport { @@ -97,7 +97,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
@Override
protected void openConnection() throws Exception {
HttpHeaders headers = new HttpHeaders();
headers.setSecWebSocketProtocol(this.subProtocols);
headers.setSecWebSocketProtocol(this.protocols);
this.webSocketSession = this.client.doHandshake(this.webSocketHandler, headers, getUri());
}

4
spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/EndpointConnectionManager.java

@ -74,8 +74,8 @@ public class EndpointConnectionManager extends ConnectionManagerSupport implemen @@ -74,8 +74,8 @@ public class EndpointConnectionManager extends ConnectionManagerSupport implemen
}
public void setSubProtocols(String... subprotocols) {
this.configBuilder.preferredSubprotocols(Arrays.asList(subprotocols));
public void setSupportedProtocols(String... protocols) {
this.configBuilder.preferredSubprotocols(Arrays.asList(protocols));
}
public void setExtensions(Extension... extensions) {

26
spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java

@ -21,7 +21,6 @@ import java.nio.charset.Charset; @@ -21,7 +21,6 @@ import java.nio.charset.Charset;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@ -35,7 +34,6 @@ import org.springframework.http.HttpStatus; @@ -35,7 +34,6 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
@ -55,7 +53,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler { @@ -55,7 +53,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
protected Log logger = LogFactory.getLog(getClass());
private List<String> supportedProtocols = new ArrayList<String>();
private final List<String> supportedProtocols = new ArrayList<String>();
private final RequestUpgradeStrategy requestUpgradeStrategy;
@ -78,11 +76,22 @@ public class DefaultHandshakeHandler implements HandshakeHandler { @@ -78,11 +76,22 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
this.requestUpgradeStrategy = upgradeStrategy;
}
/**
* Use this property to configure a list of sub-protocols that are supported.
* The first protocol that matches what the client requested is selected.
* If no protocol matches or this property is not configured, then the
* response will not contain a Sec-WebSocket-Protocol header.
*/
public void setSupportedProtocols(String... protocols) {
this.supportedProtocols = Arrays.asList(protocols);
this.supportedProtocols.clear();
for (String protocol : protocols) {
this.supportedProtocols.add(protocol.toLowerCase());
}
}
/**
* Return the list of supported sub-protocols.
*/
public String[] getSupportedProtocols() {
return this.supportedProtocols.toArray(new String[this.supportedProtocols.size()]);
}
@ -191,9 +200,12 @@ public class DefaultHandshakeHandler implements HandshakeHandler { @@ -191,9 +200,12 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
}
protected String selectProtocol(List<String> requestedProtocols) {
if (CollectionUtils.isEmpty(requestedProtocols)) {
if (requestedProtocols != null) {
for (String protocol : requestedProtocols) {
if (this.supportedProtocols.contains(protocol)) {
if (this.supportedProtocols.contains(protocol.toLowerCase())) {
if (logger.isDebugEnabled()) {
logger.debug("Selected sub-protocol '" + protocol + "'");
}
return protocol;
}
}

8
spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/ServerEndpointRegistration.java

@ -60,7 +60,7 @@ public class ServerEndpointRegistration implements ServerEndpointConfig, BeanFac @@ -60,7 +60,7 @@ public class ServerEndpointRegistration implements ServerEndpointConfig, BeanFac
private List<Class<? extends Decoder>> decoders = new ArrayList<Class<? extends Decoder>>();
private List<String> subprotocols = new ArrayList<String>();
private List<String> protocols = new ArrayList<String>();
private List<Extension> extensions = new ArrayList<Extension>();
@ -113,13 +113,13 @@ public class ServerEndpointRegistration implements ServerEndpointConfig, BeanFac @@ -113,13 +113,13 @@ public class ServerEndpointRegistration implements ServerEndpointConfig, BeanFac
return (this.endpoint != null) ? this.endpoint : this.endpointProvider.getHandler();
}
public void setSubprotocols(List<String> subprotocols) {
this.subprotocols = subprotocols;
public void setSubprotocols(List<String> protocols) {
this.protocols = protocols;
}
@Override
public List<String> getSubprotocols() {
return this.subprotocols;
return this.protocols;
}
public void setExtensions(List<Extension> extensions) {

2
spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java

@ -48,7 +48,7 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS @@ -48,7 +48,7 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS
String protocol, WebSocketHandler handler) throws IOException, HandshakeFailureException {
StandardWebSocketSessionAdapter session = new StandardWebSocketSessionAdapter();
this.wsSessionInitializer.initialize(request, response, session);
this.wsSessionInitializer.initialize(request, response, protocol, session);
StandardEndpointAdapter endpoint = new StandardEndpointAdapter(handler, session);
upgradeInternal(request, response, protocol, endpoint);
}

4
spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java

@ -87,7 +87,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { @@ -87,7 +87,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy {
@Override
public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
String selectedProtocol, WebSocketHandler webSocketHandler) throws IOException {
String protocol, WebSocketHandler webSocketHandler) throws IOException {
Assert.isInstanceOf(ServletServerHttpRequest.class, request);
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
@ -101,7 +101,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { @@ -101,7 +101,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy {
}
JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter();
this.wsSessionInitializer.initialize(request, response, session);
this.wsSessionInitializer.initialize(request, response, protocol, session);
JettyWebSocketListenerAdapter listener = new JettyWebSocketListenerAdapter(webSocketHandler, session);
servletRequest.setAttribute(WEBSOCKET_LISTENER_ATTR_NAME, listener);

5
spring-websocket/src/main/java/org/springframework/web/socket/server/support/ServerWebSocketSessionInitializer.java

@ -30,11 +30,14 @@ import org.springframework.web.socket.adapter.ConfigurableWebSocketSession; @@ -30,11 +30,14 @@ import org.springframework.web.socket.adapter.ConfigurableWebSocketSession;
*/
public class ServerWebSocketSessionInitializer {
public void initialize(ServerHttpRequest request, ServerHttpResponse response, ConfigurableWebSocketSession session) {
public void initialize(ServerHttpRequest request, ServerHttpResponse response,
String protocol, ConfigurableWebSocketSession session) {
session.setUri(request.getURI());
session.setRemoteHostName(request.getRemoteHostName());
session.setRemoteAddress(request.getRemoteAddress());
session.setPrincipal(request.getPrincipal());
session.setAcceptedProtocol(protocol);
}
}

3
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/DefaultSockJsService.java

@ -245,7 +245,8 @@ public class DefaultSockJsService extends AbstractSockJsService { @@ -245,7 +245,8 @@ public class DefaultSockJsService extends AbstractSockJsService {
}
logger.debug("Creating new session with session id \"" + sessionId + "\"");
session = sessionFactory.createSession(sessionId, handler);
this.sessionInitializer.initialize(request, response, session);
String protocol = null; // TODO: https://github.com/sockjs/sockjs-client/issues/130
this.sessionInitializer.initialize(request, response, protocol, session);
this.sessions.put(sessionId, session);
return session;
}

3
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/AbstractHttpReceivingTransportHandler.java

@ -63,6 +63,9 @@ public abstract class AbstractHttpReceivingTransportHandler implements Transport @@ -63,6 +63,9 @@ public abstract class AbstractHttpReceivingTransportHandler implements Transport
return;
}
// TODO: check "Sec-WebSocket-Protocol" header
// https://github.com/sockjs/sockjs-client/issues/130
handleRequestInternal(request, response, session);
}

20
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/AbstractHttpSockJsSession.java

@ -49,12 +49,32 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { @@ -49,12 +49,32 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession {
private ServerHttpResponse response;
private String protocol;
public AbstractHttpSockJsSession(String sessionId, SockJsConfiguration config, WebSocketHandler handler) {
super(sessionId, config, handler);
}
/**
* Unlike WebSocket where sub-protocol negotiation is part of the
* initial handshake, in HTTP transports the same negotiation must
* be emulated and the selected protocol set through this setter.
*
* @param protocol the sub-protocol to set
*/
public void setAcceptedProtocol(String protocol) {
this.protocol = protocol;
}
/**
* Return the selected sub-protocol to use.
*/
public String getAcceptedProtocol() {
return this.protocol;
}
public synchronized void setInitialRequest(ServerHttpRequest request, ServerHttpResponse response,
FrameFormat frameFormat) throws TransportErrorException {

14
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/WebSocketServerSockJsSession.java

@ -48,6 +48,20 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession { @@ -48,6 +48,20 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession {
}
@Override
public String getAcceptedProtocol() {
if (this.webSocketSession == null) {
logger.warn("getAcceptedProtocol() invoked before WebSocketSession has been initialized.");
return null;
}
return this.webSocketSession.getAcceptedProtocol();
}
@Override
public void setAcceptedProtocol(String protocol) {
// ignore, webSocketSession should have it
}
public void initWebSocketSession(WebSocketSession session) throws Exception {
this.webSocketSession = session;
try {

125
spring-websocket/src/main/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandler.java

@ -0,0 +1,125 @@ @@ -0,0 +1,125 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.support;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.springframework.util.Assert;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* A {@link WebSocketHandler} that delegates to other {@link WebSocketHandler} instances
* based on the sub-protocol value accepted at the handshake. A default handler can also
* be configured for use by default when a sub-protocol value if the WebSocket session
* does not have a sub-protocol value associated with it.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class MultiProtocolWebSocketHandler implements WebSocketHandler {
private WebSocketHandler defaultHandler;
private Map<String, WebSocketHandler> handlers = new HashMap<String, WebSocketHandler>();
/**
* Configure {@link WebSocketHandler}'s to use by sub-protocol. The values for
* sub-protocols are case insensitive.
*/
public void setProtocolHandlers(Map<String, WebSocketHandler> protocolHandlers) {
this.handlers.clear();
for (String protocol : protocolHandlers.keySet()) {
this.handlers.put(protocol.toLowerCase(), protocolHandlers.get(protocol));
}
}
/**
* Return a read-only copy of the sub-protocol handler map.
*/
public Map<String, WebSocketHandler> getProtocolHandlers() {
return Collections.unmodifiableMap(this.handlers);
}
/**
* Set the default {@link WebSocketHandler} to use if a sub-protocol was not
* requested.
*/
public void setDefaultProtocolHandler(WebSocketHandler defaultHandler) {
this.defaultHandler = defaultHandler;
}
/**
* Return the default {@link WebSocketHandler} to be used.
*/
public WebSocketHandler getDefaultProtocolHandler() {
return this.defaultHandler;
}
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
WebSocketHandler handler = getHandlerForSession(session);
handler.afterConnectionEstablished(session);
}
private WebSocketHandler getHandlerForSession(WebSocketSession session) {
WebSocketHandler handler = null;
String protocol = session.getAcceptedProtocol();
if (protocol != null) {
handler = this.handlers.get(protocol.toLowerCase());
Assert.state(handler != null,
"No WebSocketHandler for sub-protocol '" + protocol + "', handlers=" + this.handlers);
}
else {
handler = this.defaultHandler;
Assert.state(handler != null,
"No sub-protocol was requested and no default WebSocketHandler was configured");
}
return handler;
}
@Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
WebSocketHandler handler = getHandlerForSession(session);
handler.handleMessage(session, message);
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
WebSocketHandler handler = getHandlerForSession(session);
handler.handleTransportError(session, exception);
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
WebSocketHandler handler = getHandlerForSession(session);
handler.afterConnectionClosed(session, closeStatus);
}
@Override
public boolean supportsPartialMessages() {
return false;
}
}

2
spring-websocket/src/test/java/org/springframework/web/socket/client/WebSocketConnectionManagerTests.java

@ -51,7 +51,7 @@ public class WebSocketConnectionManagerTests { @@ -51,7 +51,7 @@ public class WebSocketConnectionManagerTests {
WebSocketHandler handler = new WebSocketHandlerAdapter();
WebSocketConnectionManager manager = new WebSocketConnectionManager(client, handler , "/path/{id}", "123");
manager.setSubProtocols(subprotocols);
manager.setSupportedProtocols(subprotocols);
manager.openConnection();
ArgumentCaptor<WebSocketHandlerDecorator> captor = ArgumentCaptor.forClass(WebSocketHandlerDecorator.class);

71
spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java

@ -0,0 +1,71 @@ @@ -0,0 +1,71 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.server;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.web.socket.AbstractHttpRequestTests;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.adapter.TextWebSocketHandlerAdapter;
import static org.mockito.Mockito.*;
/**
* Test fixture for {@link DefaultHandshakeHandler}.
*
* @author Rossen Stoyanchev
*/
public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
private DefaultHandshakeHandler handshakeHandler;
@Mock
private RequestUpgradeStrategy upgradeStrategy;
@Before
public void setup() throws Exception {
MockitoAnnotations.initMocks(this);
this.handshakeHandler = new DefaultHandshakeHandler(this.upgradeStrategy);
}
@Test
public void selectSubProtocol() throws Exception {
this.handshakeHandler.setSupportedProtocols("stomp", "mqtt");
when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] { "13" });
this.servletRequest.setMethod("GET");
this.request.getHeaders().setUpgrade("WebSocket");
this.request.getHeaders().setConnection("Upgrade");
this.request.getHeaders().setSecWebSocketVersion("13");
this.request.getHeaders().setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
this.request.getHeaders().setSecWebSocketProtocol("STOMP");
WebSocketHandler handler = new TextWebSocketHandlerAdapter();
this.handshakeHandler.doHandshake(this.request, this.response, handler);
verify(this.upgradeStrategy).upgrade(request, response, "STOMP", handler);
}
}

12
spring-websocket/src/test/java/org/springframework/web/socket/sockjs/TestSockJsSession.java

@ -40,11 +40,23 @@ public class TestSockJsSession extends AbstractSockJsSession { @@ -40,11 +40,23 @@ public class TestSockJsSession extends AbstractSockJsSession {
private boolean cancelledHeartbeat;
private String subProtocol;
public TestSockJsSession(String sessionId, SockJsConfiguration config, WebSocketHandler handler) {
super(sessionId, config, handler);
}
@Override
public String getAcceptedProtocol() {
return this.subProtocol;
}
@Override
public void setAcceptedProtocol(String protocol) {
this.subProtocol = protocol;
}
public CloseStatus getStatus() {
return this.status;
}

106
spring-websocket/src/test/java/org/springframework/web/socket/support/MultiProtocolWebSocketHandlerTests.java

@ -0,0 +1,106 @@ @@ -0,0 +1,106 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIOsNS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.support;
import java.util.HashMap;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.web.socket.WebSocketHandler;
import static org.mockito.Mockito.*;
/**
* Test fixture for {@link MultiProtocolWebSocketHandler}.
*
* @author Rossen Stoyanchev
*/
public class MultiProtocolWebSocketHandlerTests {
private MultiProtocolWebSocketHandler multiProtocolHandler;
@Mock
WebSocketHandler stompHandler;
@Mock
WebSocketHandler mqttHandler;
@Mock
WebSocketHandler defaultHandler;
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
Map<String, WebSocketHandler> handlers = new HashMap<String, WebSocketHandler>();
handlers.put("STOMP", this.stompHandler);
handlers.put("MQTT", this.mqttHandler);
this.multiProtocolHandler = new MultiProtocolWebSocketHandler();
this.multiProtocolHandler.setProtocolHandlers(handlers);
this.multiProtocolHandler.setDefaultProtocolHandler(this.defaultHandler);
}
@Test
public void subProtocol() throws Exception {
TestWebSocketSession session = new TestWebSocketSession();
session.setAcceptedProtocol("sToMp");
this.multiProtocolHandler.afterConnectionEstablished(session);
verify(this.stompHandler).afterConnectionEstablished(session);
verifyZeroInteractions(this.mqttHandler);
}
@Test(expected=IllegalStateException.class)
public void subProtocolNoMatch() throws Exception {
TestWebSocketSession session = new TestWebSocketSession();
session.setAcceptedProtocol("wamp");
this.multiProtocolHandler.afterConnectionEstablished(session);
}
@Test
public void noSubProtocol() throws Exception {
TestWebSocketSession session = new TestWebSocketSession();
this.multiProtocolHandler.afterConnectionEstablished(session);
verify(this.defaultHandler).afterConnectionEstablished(session);
verifyZeroInteractions(this.stompHandler, this.mqttHandler);
}
@Test(expected=IllegalStateException.class)
public void noSubProtocolNoDefaultHandler() throws Exception {
TestWebSocketSession session = new TestWebSocketSession();
this.multiProtocolHandler.setDefaultProtocolHandler(null);
this.multiProtocolHandler.afterConnectionEstablished(session);
}
}

16
spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java

@ -45,6 +45,8 @@ public class TestWebSocketSession implements WebSocketSession { @@ -45,6 +45,8 @@ public class TestWebSocketSession implements WebSocketSession {
private String remoteAddress;
private String protocol;
private boolean open;
private final List<WebSocketMessage<?>> messages = new ArrayList<>();
@ -142,6 +144,20 @@ public class TestWebSocketSession implements WebSocketSession { @@ -142,6 +144,20 @@ public class TestWebSocketSession implements WebSocketSession {
this.remoteAddress = remoteAddress;
}
/**
* @return the subProtocol
*/
public String getAcceptedProtocol() {
return this.protocol;
}
/**
* @param protocol the subProtocol to set
*/
public void setAcceptedProtocol(String protocol) {
this.protocol = protocol;
}
/**
* @return the open
*/

Loading…
Cancel
Save