Browse Source

Add support for WebSocket Protocol Extensions

This commits adds simple, overridable WebSocket Extension
filtering during the handshake phase and adds that
information in the WebSocket session.

The actual WebSocket Extension negotiation happens
within the server implementation (Glassfish, Jetty, Tomcat...),
so one can only remove requested extensions from
the list provided by the WebSocket client.

See RFC6455 Section 9.

Issue: SPR-10843
pull/395/merge
Brian Clozel 11 years ago committed by Rossen Stoyanchev
parent
commit
6d00a3f0ee
  1. 2
      spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java
  2. 167
      spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java
  3. 7
      spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java
  4. 18
      spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java
  5. 24
      spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java
  6. 34
      spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java
  7. 8
      spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java
  8. 3
      spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java
  9. 29
      spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java
  10. 11
      spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java
  11. 16
      spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java
  12. 18
      spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java
  13. 7
      spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java
  14. 8
      spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java
  15. 61
      spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java
  16. 17
      spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java
  17. 17
      spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java

2
spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java

@ -69,7 +69,7 @@ public interface SubProtocolHandler { @@ -69,7 +69,7 @@ public interface SubProtocolHandler {
/**
* Resolve the session id from the given message or return {@code null}.
*
* @param the message to resolve the session id from
* @param message the message to resolve the session id from
*/
String resolveSessionId(Message<?> message);

167
spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java

@ -0,0 +1,167 @@ @@ -0,0 +1,167 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedCaseInsensitiveMap;
import org.springframework.util.StringUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
/**
* WebSocket Protocol extension.
* Adds new protocol features to the WebSocket protocol; the extensions used
* within a session are negotiated during the handshake phase:
* <ul>
* <li>the client may ask for specific extensions in the HTTP request</li>
* <li>the server declares the final list of supported extensions for the current session in the HTTP response</li>
* </ul>
*
* <p>WebSocket Extension HTTP headers may include parameters and follow
* <a href="https://tools.ietf.org/html/rfc2616#section-4.2">RFC 2616 Section 4.2</a>
* specifications.</p>
*
* <p>Note that the order of extensions in HTTP headers defines their order of execution,
* e.g. extensions "foo, bar" will be executed as "bar(foo(message))".</p>
*
* @author Brian Clozel
* @since 4.0
* @see <a href="https://tools.ietf.org/html/rfc6455#section-9">
* WebSocket Protocol Extensions, RFC 6455 - Section 9</a>
*/
public class WebSocketExtension {
private final String name;
private final Map<String, String> parameters;
public WebSocketExtension(String name) {
this(name,null);
}
public WebSocketExtension(String name, Map<String, String> parameters) {
Assert.hasLength(name, "extension name must not be empty");
this.name = name;
if (!CollectionUtils.isEmpty(parameters)) {
Map<String, String> m = new LinkedCaseInsensitiveMap<String>(parameters.size(), Locale.ENGLISH);
m.putAll(parameters);
this.parameters = Collections.unmodifiableMap(m);
}
else {
this.parameters = Collections.emptyMap();
}
}
/**
* @return the name of the extension
*/
public String getName() {
return this.name;
}
/**
* @return the parameters of the extension
*/
public Map<String, String> getParameters() {
return this.parameters;
}
/**
* Parse a list of raw WebSocket extension headers
*/
public static List<WebSocketExtension> parseHeaders(List<String> headers) {
if (headers == null || headers.isEmpty()) {
return Collections.emptyList();
}
else {
List<WebSocketExtension> result = new ArrayList<WebSocketExtension>(headers.size());
for (String header : headers) {
result.addAll(parseHeader(header));
}
return result;
}
}
/**
* Parse a raw WebSocket extension header
*/
public static List<WebSocketExtension> parseHeader(String header) {
if (header == null || !StringUtils.hasText(header)) {
return Collections.emptyList();
}
else {
List<WebSocketExtension> result = new ArrayList<WebSocketExtension>();
for(String token : header.split(",")) {
result.add(parse(token));
}
return result;
}
}
private static WebSocketExtension parse(String extension) {
Assert.doesNotContain(extension,",","this string contains multiple extension declarations");
String[] parts = StringUtils.tokenizeToStringArray(extension, ";");
String name = parts[0].trim();
Map<String, String> parameters = null;
if (parts.length > 1) {
parameters = new LinkedHashMap<String, String>(parts.length - 1);
for (int i = 1; i < parts.length; i++) {
String parameter = parts[i];
int eqIndex = parameter.indexOf('=');
if (eqIndex != -1) {
String attribute = parameter.substring(0, eqIndex);
String value = parameter.substring(eqIndex + 1, parameter.length());
parameters.put(attribute, value);
}
}
}
return new WebSocketExtension(name,parameters);
}
/**
* Convert a list of WebSocketExtensions to a list of String,
* which is convenient for native HTTP headers.
*/
public static List<String> toStringList(List<WebSocketExtension> extensions) {
List<String> result = new ArrayList<String>(extensions.size());
for(WebSocketExtension extension : extensions) {
result.add(extension.toString());
}
return result;
}
@Override
public String toString() {
StringBuilder str = new StringBuilder();
str.append(this.name);
for (String param : parameters.keySet()) {
str.append(';');
str.append(param);
str.append('=');
str.append(this.parameters.get(param));
}
return str.toString();
}
}

7
spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java

@ -20,6 +20,7 @@ import java.io.IOException; @@ -20,6 +20,7 @@ import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import java.util.List;
import java.util.Map;
import org.springframework.http.HttpHeaders;
@ -79,6 +80,12 @@ public interface WebSocketSession { @@ -79,6 +80,12 @@ public interface WebSocketSession {
*/
String getAcceptedProtocol();
/**
* Return the negotiated extensions or {@code null} if none was specified or
* negotiated successfully.
*/
List<WebSocketExtension> getExtensions();
/**
* Return whether the connection is still open.
*/

18
spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java

@ -20,8 +20,11 @@ import java.io.IOException; @@ -20,8 +20,11 @@ import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig;
import org.springframework.http.HttpHeaders;
import org.springframework.util.ObjectUtils;
import org.springframework.web.socket.BinaryMessage;
@ -29,6 +32,7 @@ import org.springframework.web.socket.CloseStatus; @@ -29,6 +32,7 @@ import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PingMessage;
import org.springframework.web.socket.PongMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketSession;
/**
@ -42,6 +46,8 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion<org.eclipse @@ -42,6 +46,8 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion<org.eclipse
private HttpHeaders headers;
private List<WebSocketExtension> extensions;
private final Principal principal;
@ -104,6 +110,18 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion<org.eclipse @@ -104,6 +110,18 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion<org.eclipse
return getNativeSession().getUpgradeResponse().getAcceptedSubProtocol();
}
@Override
public List<WebSocketExtension> getExtensions() {
checkNativeSessionInitialized();
if(this.extensions == null) {
this.extensions = new ArrayList<WebSocketExtension>();
for(ExtensionConfig ext : getNativeSession().getUpgradeResponse().getExtensions()) {
this.extensions.add(new WebSocketExtension(ext.getName(),ext.getParameters()));
}
}
return this.extensions;
}
@Override
public boolean isOpen() {
return ((getNativeSession() != null) && getNativeSession().isOpen());

24
spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java

@ -20,10 +20,14 @@ import java.io.IOException; @@ -20,10 +20,14 @@ import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCodes;
import javax.websocket.Extension;
import org.springframework.http.HttpHeaders;
import org.springframework.util.StringUtils;
@ -32,6 +36,7 @@ import org.springframework.web.socket.CloseStatus; @@ -32,6 +36,7 @@ import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PingMessage;
import org.springframework.web.socket.PongMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketSession;
/**
@ -48,6 +53,8 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion<javax.we @@ -48,6 +53,8 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion<javax.we
private final InetSocketAddress remoteAddress;
private List<WebSocketExtension> extensions;
/**
* Class constructor.
@ -108,6 +115,23 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion<javax.we @@ -108,6 +115,23 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion<javax.we
return StringUtils.isEmpty(protocol)? null : protocol;
}
@Override
public List<WebSocketExtension> getExtensions() {
checkNativeSessionInitialized();
if(this.extensions == null) {
List<Extension> nativeExtensions = getNativeSession().getNegotiatedExtensions();
this.extensions = new ArrayList<WebSocketExtension>(nativeExtensions.size());
for(Extension nativeExtension : nativeExtensions) {
Map<String, String> parameters = new HashMap<String, String>();
for (Extension.Parameter param : nativeExtension.getParameters()) {
parameters.put(param.getName(),param.getValue());
}
this.extensions.add(new WebSocketExtension(nativeExtension.getName(),parameters));
}
}
return this.extensions;
}
@Override
public boolean isOpen() {
return ((getNativeSession() != null) && getNativeSession().isOpen());

34
spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java

@ -32,13 +32,15 @@ import org.springframework.http.server.ServerHttpRequest; @@ -32,13 +32,15 @@ import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
/**
* A default {@link HandshakeHandler} implementation. Performs initial validation of the
* WebSocket handshake request -- possibly rejecting it through the appropriate HTTP
* status code -- while also allowing sub-classes to override various parts of the
* negotiation process (e.g. origin validation, sub-protocol negotiation, etc).
* negotiation process (e.g. origin validation, sub-protocol negotiation,
* extensions negotiation, etc).
*
* <p>
* If the negotiation succeeds, the actual upgrade is delegated to a server-specific
@ -188,6 +190,13 @@ public class DefaultHandshakeHandler implements HandshakeHandler { @@ -188,6 +190,13 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
logger.debug("Upgrading request, sub-protocol=" + subProtocol);
}
List<WebSocketExtension> requestedExtensions = WebSocketExtension
.parseHeaders(request.getHeaders().getSecWebSocketExtensions());
List<WebSocketExtension> filteredExtensions = filterRequestedExtensions(requestedExtensions,
this.requestUpgradeStrategy.getAvailableExtensions(request));
request.getHeaders().setSecWebSocketExtensions(WebSocketExtension.toStringList(filteredExtensions));
this.requestUpgradeStrategy.upgrade(request, response, subProtocol, wsHandler, attributes);
return true;
@ -254,4 +263,27 @@ public class DefaultHandshakeHandler implements HandshakeHandler { @@ -254,4 +263,27 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
return null;
}
/**
* Filter the list of WebSocket Extensions requested by the client.
* Since the negotiation process happens during the upgrade phase within the server
* implementation, one can customize the applied extensions only by filtering the
* requested extensions by the client.
*
* <p>The default implementation of this method doesn't filter any of the extensions
* requested by the client.
* @param requestedExtensions the list of extensions requested by the client
* @param supportedExtensions the list of extensions supported by the server
* @return the filtered list of requested extensions
*/
protected List<WebSocketExtension> filterRequestedExtensions(List<WebSocketExtension> requestedExtensions,
List<WebSocketExtension> supportedExtensions) {
if (requestedExtensions != null) {
if (logger.isDebugEnabled()) {
logger.debug("Requested extension(s): " + requestedExtensions
+ ", supported extension(s): " + supportedExtensions);
}
}
return requestedExtensions;
}
}

8
spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java

@ -16,10 +16,12 @@ @@ -16,10 +16,12 @@
package org.springframework.web.socket.server;
import java.util.List;
import java.util.Map;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
/**
@ -36,6 +38,12 @@ public interface RequestUpgradeStrategy { @@ -36,6 +38,12 @@ public interface RequestUpgradeStrategy {
*/
String[] getSupportedVersions();
/**
* @return the list of available WebSocket protocol extensions,
* implemented by the underlying WebSocket server.
*/
List<WebSocketExtension> getAvailableExtensions(ServerHttpRequest request);
/**
* Perform runtime specific steps to complete the upgrade. Invoked after successful
* negotiation of the handshake request.

3
spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java

@ -16,8 +16,7 @@ @@ -16,8 +16,7 @@
package org.springframework.web.socket.server.config;
import org.eclipse.jetty.websocket.server.WebSocketHandler;
import org.springframework.web.socket.WebSocketHandler;
/**

29
spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java

@ -19,15 +19,21 @@ package org.springframework.web.socket.server.support; @@ -19,15 +19,21 @@ package org.springframework.web.socket.server.support;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.Extension;
import javax.websocket.server.ServerContainer;
import org.apache.tomcat.websocket.server.WsServerContainer;
import org.glassfish.tyrus.core.ComponentProviderService;
import org.glassfish.tyrus.core.EndpointWrapper;
import org.glassfish.tyrus.core.ErrorCollector;
@ -47,6 +53,7 @@ import org.springframework.util.Assert; @@ -47,6 +53,7 @@ import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration;
import org.springframework.web.socket.server.endpoint.ServletServerContainerFactoryBean;
@ -66,11 +73,26 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt @@ -66,11 +73,26 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt
private final static Random random = new Random();
private List<WebSocketExtension> availableExtensions;
@Override
public String[] getSupportedVersions() {
return StringUtils.commaDelimitedListToStringArray(Version.getSupportedWireProtocolVersions());
}
@Override
public List<WebSocketExtension> getAvailableExtensions(ServerHttpRequest request) {
if(this.availableExtensions == null) {
this.availableExtensions = new ArrayList<WebSocketExtension>();
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
for(Extension extension : getContainer(servletRequest).getInstalledExtensions()) {
this.availableExtensions.add(parseStandardExtension(extension));
}
}
return this.availableExtensions;
}
@Override
public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
String selectedProtocol, Endpoint endpoint) throws HandshakeFailureException {
@ -103,6 +125,13 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt @@ -103,6 +125,13 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt
}
}
public ServerContainer getContainer(HttpServletRequest servletRequest) {
String attributeName = "javax.websocket.server.ServerContainer";
ServletContext servletContext = servletRequest.getServletContext();
return (ServerContainer)servletContext.getAttribute(attributeName);
}
private boolean performUpgrade(HttpServletRequest request, HttpServletResponse response,
HttpHeaders headers, WebSocketApplication wsApp) throws IOException {

11
spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java

@ -17,15 +17,18 @@ @@ -17,15 +17,18 @@
package org.springframework.web.socket.server.support;
import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.Map;
import javax.websocket.Endpoint;
import javax.websocket.Extension;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.StandardWebSocketSession;
@ -62,4 +65,12 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS @@ -62,4 +65,12 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS
protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
String selectedProtocol, Endpoint endpoint) throws HandshakeFailureException;
protected WebSocketExtension parseStandardExtension(Extension extension) {
Map<String, String> params = new HashMap<String,String>(extension.getParameters().size());
for(Extension.Parameter param : extension.getParameters()) {
params.put(param.getName(),param.getValue());
}
return new WebSocketExtension(extension.getName(),params);
}
}

16
spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java

@ -17,6 +17,8 @@ @@ -17,6 +17,8 @@
package org.springframework.web.socket.server.support;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
@ -34,6 +36,7 @@ import org.springframework.http.server.ServerHttpResponse; @@ -34,6 +36,7 @@ import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.JettyWebSocketSession;
@ -54,6 +57,8 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { @@ -54,6 +57,8 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy {
private WebSocketServerFactory factory;
private List<WebSocketExtension> availableExtensions;
/**
* Default constructor that creates {@link WebSocketServerFactory} through its default
@ -92,6 +97,17 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { @@ -92,6 +97,17 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy {
return new String[] { String.valueOf(HandshakeRFC6455.VERSION) };
}
@Override
public List<WebSocketExtension> getAvailableExtensions(ServerHttpRequest request) {
if(this.availableExtensions == null) {
this.availableExtensions = new ArrayList<WebSocketExtension>();
for(String extensionName : this.factory.getExtensionFactory().getExtensionNames()) {
this.availableExtensions.add(new WebSocketExtension(extensionName));
}
}
return this.availableExtensions;
}
@Override
public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
String protocol, WebSocketHandler wsHandler, Map<String, Object> attrs) throws HandshakeFailureException {

18
spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java

@ -17,8 +17,10 @@ @@ -17,8 +17,10 @@
package org.springframework.web.socket.server.support;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import javax.servlet.ServletContext;
@ -26,6 +28,7 @@ import javax.servlet.ServletException; @@ -26,6 +28,7 @@ import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.Endpoint;
import javax.websocket.Extension;
import org.apache.tomcat.websocket.server.WsServerContainer;
import org.springframework.http.server.ServerHttpRequest;
@ -33,6 +36,7 @@ import org.springframework.http.server.ServerHttpResponse; @@ -33,6 +36,7 @@ import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration;
import org.springframework.web.socket.server.endpoint.ServletServerContainerFactoryBean;
@ -50,12 +54,26 @@ import org.springframework.web.socket.server.endpoint.ServletServerContainerFact @@ -50,12 +54,26 @@ import org.springframework.web.socket.server.endpoint.ServletServerContainerFact
*/
public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy {
private List<WebSocketExtension> availableExtensions;
@Override
public String[] getSupportedVersions() {
return new String[] { "13" };
}
@Override
public List<WebSocketExtension> getAvailableExtensions(ServerHttpRequest request) {
if(this.availableExtensions == null) {
this.availableExtensions = new ArrayList<WebSocketExtension>();
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
for(Extension extension : getContainer(servletRequest).getInstalledExtensions()) {
this.availableExtensions.add(parseStandardExtension(extension));
}
}
return this.availableExtensions;
}
@Override
public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
String acceptedProtocol, Endpoint endpoint) throws HandshakeFailureException {

7
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java

@ -20,6 +20,7 @@ import java.io.IOException; @@ -20,6 +20,7 @@ import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
@ -30,6 +31,7 @@ import org.springframework.http.server.ServerHttpRequest; @@ -30,6 +31,7 @@ import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.sockjs.SockJsException;
import org.springframework.web.socket.sockjs.SockJsTransportFailureException;
@ -66,6 +68,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { @@ -66,6 +68,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession {
private String acceptedProtocol;
private List<WebSocketExtension> extensions;
public AbstractHttpSockJsSession(String id, SockJsServiceConfig config,
WebSocketHandler wsHandler, Map<String, Object> handshakeAttributes) {
@ -116,6 +119,9 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { @@ -116,6 +119,9 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession {
this.remoteAddress = remoteAddress;
}
@Override
public List<WebSocketExtension> getExtensions() { return this.extensions; }
/**
* Unlike WebSocket where sub-protocol negotiation is part of the
* initial handshake, in HTTP transports the same negotiation must
@ -152,6 +158,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { @@ -152,6 +158,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession {
this.principal = request.getPrincipal();
this.localAddress = request.getLocalAddress();
this.remoteAddress = request.getRemoteAddress();
this.extensions = WebSocketExtension.parseHeaders(response.getHeaders().getSecWebSocketExtensions());
try {
delegateConnectionEstablished();

8
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java

@ -20,6 +20,7 @@ import java.io.IOException; @@ -20,6 +20,7 @@ import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import java.util.List;
import java.util.Map;
import org.springframework.http.HttpHeaders;
@ -27,6 +28,7 @@ import org.springframework.util.Assert; @@ -27,6 +28,7 @@ import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.NativeWebSocketSession;
@ -89,6 +91,12 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession @@ -89,6 +91,12 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession
return this.wsSession.getAcceptedProtocol();
}
@Override
public List<WebSocketExtension> getExtensions() {
checkDelegateSessionInitialized();
return this.wsSession.getExtensions();
}
private void checkDelegateSessionInitialized() {
Assert.state(this.wsSession != null, "WebSocketSession not yet initialized");
}

61
spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java

@ -0,0 +1,61 @@ @@ -0,0 +1,61 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import org.hamcrest.Matchers;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
/**
* Test fixture for {@link WebSocketExtension}
* @author Brian Clozel
*/
public class WebSocketExtensionTests {
@Test
public void parseHeaderSingle() {
List<WebSocketExtension> extensions = WebSocketExtension.parseHeader("x-test-extension ; foo=bar");
assertThat(extensions, Matchers.hasSize(1));
WebSocketExtension extension = extensions.get(0);
assertEquals("x-test-extension", extension.getName());
assertEquals(1, extension.getParameters().size());
assertEquals("bar", extension.getParameters().get("foo"));
}
@Test
public void parseHeaderMultiple() {
List<WebSocketExtension> extensions = WebSocketExtension.parseHeader("x-foo-extension, x-bar-extension");
assertThat(extensions, Matchers.hasSize(2));
assertEquals("x-foo-extension", extensions.get(0).getName());
assertEquals("x-bar-extension", extensions.get(1).getName());
}
@Test
public void parseHeaders() {
List<String> extensions = new ArrayList<String>();
extensions.add("x-foo-extension, x-bar-extension");
extensions.add("x-test-extension");
List<WebSocketExtension> parsedExtensions = WebSocketExtension.parseHeaders(extensions);
assertThat(parsedExtensions, Matchers.hasSize(3));
}
}

17
spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java

@ -26,6 +26,7 @@ import java.util.Map; @@ -26,6 +26,7 @@ import java.util.Map;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.sockjs.support.frame.SockJsFrame;
@ -58,6 +59,8 @@ public class TestSockJsSession extends AbstractSockJsSession { @@ -58,6 +59,8 @@ public class TestSockJsSession extends AbstractSockJsSession {
private String subProtocol;
private List<WebSocketExtension> extensions = new ArrayList<WebSocketExtension>();
public TestSockJsSession(String sessionId, SockJsServiceConfig config,
WebSocketHandler wsHandler, Map<String, Object> attributes) {
@ -118,7 +121,7 @@ public class TestSockJsSession extends AbstractSockJsSession { @@ -118,7 +121,7 @@ public class TestSockJsSession extends AbstractSockJsSession {
}
/**
* @param remoteAddress the remoteAddress to set
* @param localAddress the remoteAddress to set
*/
public void setLocalAddress(InetSocketAddress localAddress) {
this.localAddress = localAddress;
@ -148,6 +151,18 @@ public class TestSockJsSession extends AbstractSockJsSession { @@ -148,6 +151,18 @@ public class TestSockJsSession extends AbstractSockJsSession {
this.subProtocol = protocol;
}
/**
* @return the extensions
*/
@Override
public List<WebSocketExtension> getExtensions() { return this.extensions; }
/**
*
* @param extensions the extensions to set
*/
public void setExtensions(List<WebSocketExtension> extensions) { this.extensions = extensions; }
public CloseStatus getCloseStatus() {
return this.closeStatus;
}

17
spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java

@ -27,6 +27,7 @@ import java.util.Map; @@ -27,6 +27,7 @@ import java.util.Map;
import org.springframework.http.HttpHeaders;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
@ -51,6 +52,8 @@ public class TestWebSocketSession implements WebSocketSession { @@ -51,6 +52,8 @@ public class TestWebSocketSession implements WebSocketSession {
private String protocol;
private List<WebSocketExtension> extensions = new ArrayList<WebSocketExtension>();
private boolean open;
private final List<WebSocketMessage<?>> messages = new ArrayList<>();
@ -149,7 +152,7 @@ public class TestWebSocketSession implements WebSocketSession { @@ -149,7 +152,7 @@ public class TestWebSocketSession implements WebSocketSession {
}
/**
* @param remoteAddress the remoteAddress to set
* @param localAddress the remoteAddress to set
*/
public void setLocalAddress(InetSocketAddress localAddress) {
this.localAddress = localAddress;
@ -184,6 +187,18 @@ public class TestWebSocketSession implements WebSocketSession { @@ -184,6 +187,18 @@ public class TestWebSocketSession implements WebSocketSession {
this.protocol = protocol;
}
/**
* @return the extensions
*/
@Override
public List<WebSocketExtension> getExtensions() { return this.extensions; }
/**
*
* @param extensions the extensions to set
*/
public void setExtensions(List<WebSocketExtension> extensions) { this.extensions = extensions; }
/**
* @return the open
*/

Loading…
Cancel
Save