From 6d00a3f0ee795e1a5bd9ffe02c584623828dbc37 Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Thu, 24 Oct 2013 11:28:21 +0200 Subject: [PATCH] Add support for WebSocket Protocol Extensions This commits adds simple, overridable WebSocket Extension filtering during the handshake phase and adds that information in the WebSocket session. The actual WebSocket Extension negotiation happens within the server implementation (Glassfish, Jetty, Tomcat...), so one can only remove requested extensions from the list provided by the WebSocket client. See RFC6455 Section 9. Issue: SPR-10843 --- .../handler/websocket/SubProtocolHandler.java | 2 +- .../web/socket/WebSocketExtension.java | 167 ++++++++++++++++++ .../web/socket/WebSocketSession.java | 7 + .../socket/adapter/JettyWebSocketSession.java | 18 ++ .../adapter/StandardWebSocketSession.java | 24 +++ .../server/DefaultHandshakeHandler.java | 34 +++- .../socket/server/RequestUpgradeStrategy.java | 8 + .../server/config/WebSocketConfigurer.java | 3 +- ...stractGlassFishRequestUpgradeStrategy.java | 29 +++ .../AbstractStandardUpgradeStrategy.java | 11 ++ .../support/JettyRequestUpgradeStrategy.java | 16 ++ .../support/TomcatRequestUpgradeStrategy.java | 18 ++ .../session/AbstractHttpSockJsSession.java | 7 + .../session/WebSocketServerSockJsSession.java | 8 + .../web/socket/WebSocketExtensionTests.java | 61 +++++++ .../transport/session/TestSockJsSession.java | 17 +- .../socket/support/TestWebSocketSession.java | 17 +- 17 files changed, 441 insertions(+), 6 deletions(-) create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java index 1d3dcb30a0..fa780ed94d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java @@ -69,7 +69,7 @@ public interface SubProtocolHandler { /** * Resolve the session id from the given message or return {@code null}. * - * @param the message to resolve the session id from + * @param message the message to resolve the session id from */ String resolveSessionId(Message message); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java new file mode 100644 index 0000000000..399a3eca81 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java @@ -0,0 +1,167 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket; + +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.StringUtils; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * WebSocket Protocol extension. + * Adds new protocol features to the WebSocket protocol; the extensions used + * within a session are negotiated during the handshake phase: + * + * + *

WebSocket Extension HTTP headers may include parameters and follow + * RFC 2616 Section 4.2 + * specifications.

+ * + *

Note that the order of extensions in HTTP headers defines their order of execution, + * e.g. extensions "foo, bar" will be executed as "bar(foo(message))".

+ * + * @author Brian Clozel + * @since 4.0 + * @see + * WebSocket Protocol Extensions, RFC 6455 - Section 9 + */ +public class WebSocketExtension { + + private final String name; + + private final Map parameters; + + public WebSocketExtension(String name) { + this(name,null); + } + + public WebSocketExtension(String name, Map parameters) { + Assert.hasLength(name, "extension name must not be empty"); + this.name = name; + if (!CollectionUtils.isEmpty(parameters)) { + Map m = new LinkedCaseInsensitiveMap(parameters.size(), Locale.ENGLISH); + m.putAll(parameters); + this.parameters = Collections.unmodifiableMap(m); + } + else { + this.parameters = Collections.emptyMap(); + } + } + + /** + * @return the name of the extension + */ + public String getName() { + return this.name; + } + + /** + * @return the parameters of the extension + */ + public Map getParameters() { + return this.parameters; + } + + /** + * Parse a list of raw WebSocket extension headers + */ + public static List parseHeaders(List headers) { + if (headers == null || headers.isEmpty()) { + return Collections.emptyList(); + } + else { + List result = new ArrayList(headers.size()); + for (String header : headers) { + result.addAll(parseHeader(header)); + } + return result; + } + } + + /** + * Parse a raw WebSocket extension header + */ + public static List parseHeader(String header) { + if (header == null || !StringUtils.hasText(header)) { + return Collections.emptyList(); + } + else { + List result = new ArrayList(); + for(String token : header.split(",")) { + result.add(parse(token)); + } + return result; + } + } + + private static WebSocketExtension parse(String extension) { + Assert.doesNotContain(extension,",","this string contains multiple extension declarations"); + String[] parts = StringUtils.tokenizeToStringArray(extension, ";"); + String name = parts[0].trim(); + + Map parameters = null; + if (parts.length > 1) { + parameters = new LinkedHashMap(parts.length - 1); + for (int i = 1; i < parts.length; i++) { + String parameter = parts[i]; + int eqIndex = parameter.indexOf('='); + if (eqIndex != -1) { + String attribute = parameter.substring(0, eqIndex); + String value = parameter.substring(eqIndex + 1, parameter.length()); + parameters.put(attribute, value); + } + } + } + + return new WebSocketExtension(name,parameters); + } + + /** + * Convert a list of WebSocketExtensions to a list of String, + * which is convenient for native HTTP headers. + */ + public static List toStringList(List extensions) { + List result = new ArrayList(extensions.size()); + for(WebSocketExtension extension : extensions) { + result.add(extension.toString()); + } + return result; + } + + @Override + public String toString() { + StringBuilder str = new StringBuilder(); + str.append(this.name); + for (String param : parameters.keySet()) { + str.append(';'); + str.append(param); + str.append('='); + str.append(this.parameters.get(param)); + } + return str.toString(); + } +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java index b935e21889..a4634ddc39 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.List; import java.util.Map; import org.springframework.http.HttpHeaders; @@ -79,6 +80,12 @@ public interface WebSocketSession { */ String getAcceptedProtocol(); + /** + * Return the negotiated extensions or {@code null} if none was specified or + * negotiated successfully. + */ + List getExtensions(); + /** * Return whether the connection is still open. */ diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java index b7bc0a8130..62222e394e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java @@ -20,8 +20,11 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.ArrayList; +import java.util.List; import java.util.Map; +import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig; import org.springframework.http.HttpHeaders; import org.springframework.util.ObjectUtils; import org.springframework.web.socket.BinaryMessage; @@ -29,6 +32,7 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PingMessage; import org.springframework.web.socket.PongMessage; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketSession; /** @@ -42,6 +46,8 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion extensions; + private final Principal principal; @@ -104,6 +110,18 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion getExtensions() { + checkNativeSessionInitialized(); + if(this.extensions == null) { + this.extensions = new ArrayList(); + for(ExtensionConfig ext : getNativeSession().getUpgradeResponse().getExtensions()) { + this.extensions.add(new WebSocketExtension(ext.getName(),ext.getParameters())); + } + } + return this.extensions; + } + @Override public boolean isOpen() { return ((getNativeSession() != null) && getNativeSession().isOpen()); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java index 7ddd3a8718..6a9fa9be8c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java @@ -20,10 +20,14 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; import javax.websocket.CloseReason; import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.Extension; import org.springframework.http.HttpHeaders; import org.springframework.util.StringUtils; @@ -32,6 +36,7 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PingMessage; import org.springframework.web.socket.PongMessage; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketSession; /** @@ -48,6 +53,8 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion extensions; + /** * Class constructor. @@ -108,6 +115,23 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion getExtensions() { + checkNativeSessionInitialized(); + if(this.extensions == null) { + List nativeExtensions = getNativeSession().getNegotiatedExtensions(); + this.extensions = new ArrayList(nativeExtensions.size()); + for(Extension nativeExtension : nativeExtensions) { + Map parameters = new HashMap(); + for (Extension.Parameter param : nativeExtension.getParameters()) { + parameters.put(param.getName(),param.getValue()); + } + this.extensions.add(new WebSocketExtension(nativeExtension.getName(),parameters)); + } + } + return this.extensions; + } + @Override public boolean isOpen() { return ((getNativeSession() != null) && getNativeSession().isOpen()); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java index 6c0bfa0e84..f4b02b3a01 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java @@ -32,13 +32,15 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; /** * A default {@link HandshakeHandler} implementation. Performs initial validation of the * WebSocket handshake request -- possibly rejecting it through the appropriate HTTP * status code -- while also allowing sub-classes to override various parts of the - * negotiation process (e.g. origin validation, sub-protocol negotiation, etc). + * negotiation process (e.g. origin validation, sub-protocol negotiation, + * extensions negotiation, etc). * *

* If the negotiation succeeds, the actual upgrade is delegated to a server-specific @@ -188,6 +190,13 @@ public class DefaultHandshakeHandler implements HandshakeHandler { logger.debug("Upgrading request, sub-protocol=" + subProtocol); } + List requestedExtensions = WebSocketExtension + .parseHeaders(request.getHeaders().getSecWebSocketExtensions()); + + List filteredExtensions = filterRequestedExtensions(requestedExtensions, + this.requestUpgradeStrategy.getAvailableExtensions(request)); + request.getHeaders().setSecWebSocketExtensions(WebSocketExtension.toStringList(filteredExtensions)); + this.requestUpgradeStrategy.upgrade(request, response, subProtocol, wsHandler, attributes); return true; @@ -254,4 +263,27 @@ public class DefaultHandshakeHandler implements HandshakeHandler { return null; } + /** + * Filter the list of WebSocket Extensions requested by the client. + * Since the negotiation process happens during the upgrade phase within the server + * implementation, one can customize the applied extensions only by filtering the + * requested extensions by the client. + * + *

The default implementation of this method doesn't filter any of the extensions + * requested by the client. + * @param requestedExtensions the list of extensions requested by the client + * @param supportedExtensions the list of extensions supported by the server + * @return the filtered list of requested extensions + */ + protected List filterRequestedExtensions(List requestedExtensions, + List supportedExtensions) { + + if (requestedExtensions != null) { + if (logger.isDebugEnabled()) { + logger.debug("Requested extension(s): " + requestedExtensions + + ", supported extension(s): " + supportedExtensions); + } + } + return requestedExtensions; + } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java index b781426445..89b197fe38 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java @@ -16,10 +16,12 @@ package org.springframework.web.socket.server; +import java.util.List; import java.util.Map; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; /** @@ -36,6 +38,12 @@ public interface RequestUpgradeStrategy { */ String[] getSupportedVersions(); + /** + * @return the list of available WebSocket protocol extensions, + * implemented by the underlying WebSocket server. + */ + List getAvailableExtensions(ServerHttpRequest request); + /** * Perform runtime specific steps to complete the upgrade. Invoked after successful * negotiation of the handshake request. diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java index 1d41f06bb1..221ac25634 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java @@ -16,8 +16,7 @@ package org.springframework.web.socket.server.config; -import org.eclipse.jetty.websocket.server.WebSocketHandler; - +import org.springframework.web.socket.WebSocketHandler; /** diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java index 6f0b783887..7349615f50 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java @@ -19,15 +19,21 @@ package org.springframework.web.socket.server.support; import java.io.IOException; import java.lang.reflect.Constructor; import java.net.URI; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Random; +import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.DeploymentException; import javax.websocket.Endpoint; +import javax.websocket.Extension; +import javax.websocket.server.ServerContainer; +import org.apache.tomcat.websocket.server.WsServerContainer; import org.glassfish.tyrus.core.ComponentProviderService; import org.glassfish.tyrus.core.EndpointWrapper; import org.glassfish.tyrus.core.ErrorCollector; @@ -47,6 +53,7 @@ import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration; import org.springframework.web.socket.server.endpoint.ServletServerContainerFactoryBean; @@ -66,11 +73,26 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt private final static Random random = new Random(); + private List availableExtensions; + @Override public String[] getSupportedVersions() { return StringUtils.commaDelimitedListToStringArray(Version.getSupportedWireProtocolVersions()); } + @Override + public List getAvailableExtensions(ServerHttpRequest request) { + + if(this.availableExtensions == null) { + this.availableExtensions = new ArrayList(); + HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); + for(Extension extension : getContainer(servletRequest).getInstalledExtensions()) { + this.availableExtensions.add(parseStandardExtension(extension)); + } + } + return this.availableExtensions; + } + @Override public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, Endpoint endpoint) throws HandshakeFailureException { @@ -103,6 +125,13 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt } } + public ServerContainer getContainer(HttpServletRequest servletRequest) { + + String attributeName = "javax.websocket.server.ServerContainer"; + ServletContext servletContext = servletRequest.getServletContext(); + return (ServerContainer)servletContext.getAttribute(attributeName); + } + private boolean performUpgrade(HttpServletRequest request, HttpServletResponse response, HttpHeaders headers, WebSocketApplication wsApp) throws IOException { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java index 925540c3fc..0f62add358 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java @@ -17,15 +17,18 @@ package org.springframework.web.socket.server.support; import java.net.InetSocketAddress; +import java.util.HashMap; import java.util.Map; import javax.websocket.Endpoint; +import javax.websocket.Extension; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.http.HttpHeaders; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.StandardWebSocketSession; @@ -62,4 +65,12 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, Endpoint endpoint) throws HandshakeFailureException; + protected WebSocketExtension parseStandardExtension(Extension extension) { + Map params = new HashMap(extension.getParameters().size()); + for(Extension.Parameter param : extension.getParameters()) { + params.put(param.getName(),param.getValue()); + } + return new WebSocketExtension(extension.getName(),params); + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java index 14f6d1ce6b..3064adb417 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java @@ -17,6 +17,8 @@ package org.springframework.web.socket.server.support; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import javax.servlet.http.HttpServletRequest; @@ -34,6 +36,7 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.JettyWebSocketSession; @@ -54,6 +57,8 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { private WebSocketServerFactory factory; + private List availableExtensions; + /** * Default constructor that creates {@link WebSocketServerFactory} through its default @@ -92,6 +97,17 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { return new String[] { String.valueOf(HandshakeRFC6455.VERSION) }; } + @Override + public List getAvailableExtensions(ServerHttpRequest request) { + if(this.availableExtensions == null) { + this.availableExtensions = new ArrayList(); + for(String extensionName : this.factory.getExtensionFactory().getExtensionNames()) { + this.availableExtensions.add(new WebSocketExtension(extensionName)); + } + } + return this.availableExtensions; + } + @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol, WebSocketHandler wsHandler, Map attrs) throws HandshakeFailureException { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java index 4a96dff264..7a5608ae8b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java @@ -17,8 +17,10 @@ package org.springframework.web.socket.server.support; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Map; import javax.servlet.ServletContext; @@ -26,6 +28,7 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.Endpoint; +import javax.websocket.Extension; import org.apache.tomcat.websocket.server.WsServerContainer; import org.springframework.http.server.ServerHttpRequest; @@ -33,6 +36,7 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration; import org.springframework.web.socket.server.endpoint.ServletServerContainerFactoryBean; @@ -50,12 +54,26 @@ import org.springframework.web.socket.server.endpoint.ServletServerContainerFact */ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy { + private List availableExtensions; @Override public String[] getSupportedVersions() { return new String[] { "13" }; } + @Override + public List getAvailableExtensions(ServerHttpRequest request) { + + if(this.availableExtensions == null) { + this.availableExtensions = new ArrayList(); + HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); + for(Extension extension : getContainer(servletRequest).getInstalledExtensions()) { + this.availableExtensions.add(parseStandardExtension(extension)); + } + } + return this.availableExtensions; + } + @Override public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, String acceptedProtocol, Endpoint endpoint) throws HandshakeFailureException { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java index 5b76febfaf..474d94dcd2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.List; import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; @@ -30,6 +31,7 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.Assert; import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.SockJsException; import org.springframework.web.socket.sockjs.SockJsTransportFailureException; @@ -66,6 +68,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { private String acceptedProtocol; + private List extensions; public AbstractHttpSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler, Map handshakeAttributes) { @@ -116,6 +119,9 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { this.remoteAddress = remoteAddress; } + @Override + public List getExtensions() { return this.extensions; } + /** * Unlike WebSocket where sub-protocol negotiation is part of the * initial handshake, in HTTP transports the same negotiation must @@ -152,6 +158,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { this.principal = request.getPrincipal(); this.localAddress = request.getLocalAddress(); this.remoteAddress = request.getRemoteAddress(); + this.extensions = WebSocketExtension.parseHeaders(response.getHeaders().getSecWebSocketExtensions()); try { delegateConnectionEstablished(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java index 8474f108c0..e66af40e7e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.List; import java.util.Map; import org.springframework.http.HttpHeaders; @@ -27,6 +28,7 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.NativeWebSocketSession; @@ -89,6 +91,12 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession return this.wsSession.getAcceptedProtocol(); } + @Override + public List getExtensions() { + checkDelegateSessionInitialized(); + return this.wsSession.getExtensions(); + } + private void checkDelegateSessionInitialized() { Assert.state(this.wsSession != null, "WebSocketSession not yet initialized"); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java new file mode 100644 index 0000000000..2a674af173 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import org.hamcrest.Matchers; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +/** + * Test fixture for {@link WebSocketExtension} + * @author Brian Clozel + */ +public class WebSocketExtensionTests { + + @Test + public void parseHeaderSingle() { + List extensions = WebSocketExtension.parseHeader("x-test-extension ; foo=bar"); + assertThat(extensions, Matchers.hasSize(1)); + WebSocketExtension extension = extensions.get(0); + assertEquals("x-test-extension", extension.getName()); + assertEquals(1, extension.getParameters().size()); + assertEquals("bar", extension.getParameters().get("foo")); + } + + @Test + public void parseHeaderMultiple() { + List extensions = WebSocketExtension.parseHeader("x-foo-extension, x-bar-extension"); + assertThat(extensions, Matchers.hasSize(2)); + assertEquals("x-foo-extension", extensions.get(0).getName()); + assertEquals("x-bar-extension", extensions.get(1).getName()); + } + + @Test + public void parseHeaders() { + List extensions = new ArrayList(); + extensions.add("x-foo-extension, x-bar-extension"); + extensions.add("x-test-extension"); + List parsedExtensions = WebSocketExtension.parseHeaders(extensions); + assertThat(parsedExtensions, Matchers.hasSize(3)); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java index 1cee547aab..9f145cdd1e 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java @@ -26,6 +26,7 @@ import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; @@ -58,6 +59,8 @@ public class TestSockJsSession extends AbstractSockJsSession { private String subProtocol; + private List extensions = new ArrayList(); + public TestSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler wsHandler, Map attributes) { @@ -118,7 +121,7 @@ public class TestSockJsSession extends AbstractSockJsSession { } /** - * @param remoteAddress the remoteAddress to set + * @param localAddress the remoteAddress to set */ public void setLocalAddress(InetSocketAddress localAddress) { this.localAddress = localAddress; @@ -148,6 +151,18 @@ public class TestSockJsSession extends AbstractSockJsSession { this.subProtocol = protocol; } + /** + * @return the extensions + */ + @Override + public List getExtensions() { return this.extensions; } + + /** + * + * @param extensions the extensions to set + */ + public void setExtensions(List extensions) { this.extensions = extensions; } + public CloseStatus getCloseStatus() { return this.closeStatus; } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java index 6f16888e68..17935ac62d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java @@ -27,6 +27,7 @@ import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; @@ -51,6 +52,8 @@ public class TestWebSocketSession implements WebSocketSession { private String protocol; + private List extensions = new ArrayList(); + private boolean open; private final List> messages = new ArrayList<>(); @@ -149,7 +152,7 @@ public class TestWebSocketSession implements WebSocketSession { } /** - * @param remoteAddress the remoteAddress to set + * @param localAddress the remoteAddress to set */ public void setLocalAddress(InetSocketAddress localAddress) { this.localAddress = localAddress; @@ -184,6 +187,18 @@ public class TestWebSocketSession implements WebSocketSession { this.protocol = protocol; } + /** + * @return the extensions + */ + @Override + public List getExtensions() { return this.extensions; } + + /** + * + * @param extensions the extensions to set + */ + public void setExtensions(List extensions) { this.extensions = extensions; } + /** * @return the open */