diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java index 1d3dcb30a0..fa780ed94d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/websocket/SubProtocolHandler.java @@ -69,7 +69,7 @@ public interface SubProtocolHandler { /** * Resolve the session id from the given message or return {@code null}. * - * @param the message to resolve the session id from + * @param message the message to resolve the session id from */ String resolveSessionId(Message message); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java new file mode 100644 index 0000000000..399a3eca81 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketExtension.java @@ -0,0 +1,167 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket; + +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.StringUtils; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * WebSocket Protocol extension. + * Adds new protocol features to the WebSocket protocol; the extensions used + * within a session are negotiated during the handshake phase: + * + * + *

WebSocket Extension HTTP headers may include parameters and follow + * RFC 2616 Section 4.2 + * specifications.

+ * + *

Note that the order of extensions in HTTP headers defines their order of execution, + * e.g. extensions "foo, bar" will be executed as "bar(foo(message))".

+ * + * @author Brian Clozel + * @since 4.0 + * @see + * WebSocket Protocol Extensions, RFC 6455 - Section 9 + */ +public class WebSocketExtension { + + private final String name; + + private final Map parameters; + + public WebSocketExtension(String name) { + this(name,null); + } + + public WebSocketExtension(String name, Map parameters) { + Assert.hasLength(name, "extension name must not be empty"); + this.name = name; + if (!CollectionUtils.isEmpty(parameters)) { + Map m = new LinkedCaseInsensitiveMap(parameters.size(), Locale.ENGLISH); + m.putAll(parameters); + this.parameters = Collections.unmodifiableMap(m); + } + else { + this.parameters = Collections.emptyMap(); + } + } + + /** + * @return the name of the extension + */ + public String getName() { + return this.name; + } + + /** + * @return the parameters of the extension + */ + public Map getParameters() { + return this.parameters; + } + + /** + * Parse a list of raw WebSocket extension headers + */ + public static List parseHeaders(List headers) { + if (headers == null || headers.isEmpty()) { + return Collections.emptyList(); + } + else { + List result = new ArrayList(headers.size()); + for (String header : headers) { + result.addAll(parseHeader(header)); + } + return result; + } + } + + /** + * Parse a raw WebSocket extension header + */ + public static List parseHeader(String header) { + if (header == null || !StringUtils.hasText(header)) { + return Collections.emptyList(); + } + else { + List result = new ArrayList(); + for(String token : header.split(",")) { + result.add(parse(token)); + } + return result; + } + } + + private static WebSocketExtension parse(String extension) { + Assert.doesNotContain(extension,",","this string contains multiple extension declarations"); + String[] parts = StringUtils.tokenizeToStringArray(extension, ";"); + String name = parts[0].trim(); + + Map parameters = null; + if (parts.length > 1) { + parameters = new LinkedHashMap(parts.length - 1); + for (int i = 1; i < parts.length; i++) { + String parameter = parts[i]; + int eqIndex = parameter.indexOf('='); + if (eqIndex != -1) { + String attribute = parameter.substring(0, eqIndex); + String value = parameter.substring(eqIndex + 1, parameter.length()); + parameters.put(attribute, value); + } + } + } + + return new WebSocketExtension(name,parameters); + } + + /** + * Convert a list of WebSocketExtensions to a list of String, + * which is convenient for native HTTP headers. + */ + public static List toStringList(List extensions) { + List result = new ArrayList(extensions.size()); + for(WebSocketExtension extension : extensions) { + result.add(extension.toString()); + } + return result; + } + + @Override + public String toString() { + StringBuilder str = new StringBuilder(); + str.append(this.name); + for (String param : parameters.keySet()) { + str.append(';'); + str.append(param); + str.append('='); + str.append(this.parameters.get(param)); + } + return str.toString(); + } +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java index b935e21889..a4634ddc39 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.List; import java.util.Map; import org.springframework.http.HttpHeaders; @@ -79,6 +80,12 @@ public interface WebSocketSession { */ String getAcceptedProtocol(); + /** + * Return the negotiated extensions or {@code null} if none was specified or + * negotiated successfully. + */ + List getExtensions(); + /** * Return whether the connection is still open. */ diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java index b7bc0a8130..62222e394e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSession.java @@ -20,8 +20,11 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.ArrayList; +import java.util.List; import java.util.Map; +import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig; import org.springframework.http.HttpHeaders; import org.springframework.util.ObjectUtils; import org.springframework.web.socket.BinaryMessage; @@ -29,6 +32,7 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PingMessage; import org.springframework.web.socket.PongMessage; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketSession; /** @@ -42,6 +46,8 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion extensions; + private final Principal principal; @@ -104,6 +110,18 @@ public class JettyWebSocketSession extends AbstractWebSocketSesssion getExtensions() { + checkNativeSessionInitialized(); + if(this.extensions == null) { + this.extensions = new ArrayList(); + for(ExtensionConfig ext : getNativeSession().getUpgradeResponse().getExtensions()) { + this.extensions.add(new WebSocketExtension(ext.getName(),ext.getParameters())); + } + } + return this.extensions; + } + @Override public boolean isOpen() { return ((getNativeSession() != null) && getNativeSession().isOpen()); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java index 7ddd3a8718..6a9fa9be8c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSession.java @@ -20,10 +20,14 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; import javax.websocket.CloseReason; import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.Extension; import org.springframework.http.HttpHeaders; import org.springframework.util.StringUtils; @@ -32,6 +36,7 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PingMessage; import org.springframework.web.socket.PongMessage; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketSession; /** @@ -48,6 +53,8 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion extensions; + /** * Class constructor. @@ -108,6 +115,23 @@ public class StandardWebSocketSession extends AbstractWebSocketSesssion getExtensions() { + checkNativeSessionInitialized(); + if(this.extensions == null) { + List nativeExtensions = getNativeSession().getNegotiatedExtensions(); + this.extensions = new ArrayList(nativeExtensions.size()); + for(Extension nativeExtension : nativeExtensions) { + Map parameters = new HashMap(); + for (Extension.Parameter param : nativeExtension.getParameters()) { + parameters.put(param.getName(),param.getValue()); + } + this.extensions.add(new WebSocketExtension(nativeExtension.getName(),parameters)); + } + } + return this.extensions; + } + @Override public boolean isOpen() { return ((getNativeSession() != null) && getNativeSession().isOpen()); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java index 6c0bfa0e84..f4b02b3a01 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java @@ -32,13 +32,15 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; /** * A default {@link HandshakeHandler} implementation. Performs initial validation of the * WebSocket handshake request -- possibly rejecting it through the appropriate HTTP * status code -- while also allowing sub-classes to override various parts of the - * negotiation process (e.g. origin validation, sub-protocol negotiation, etc). + * negotiation process (e.g. origin validation, sub-protocol negotiation, + * extensions negotiation, etc). * *

* If the negotiation succeeds, the actual upgrade is delegated to a server-specific @@ -188,6 +190,13 @@ public class DefaultHandshakeHandler implements HandshakeHandler { logger.debug("Upgrading request, sub-protocol=" + subProtocol); } + List requestedExtensions = WebSocketExtension + .parseHeaders(request.getHeaders().getSecWebSocketExtensions()); + + List filteredExtensions = filterRequestedExtensions(requestedExtensions, + this.requestUpgradeStrategy.getAvailableExtensions(request)); + request.getHeaders().setSecWebSocketExtensions(WebSocketExtension.toStringList(filteredExtensions)); + this.requestUpgradeStrategy.upgrade(request, response, subProtocol, wsHandler, attributes); return true; @@ -254,4 +263,27 @@ public class DefaultHandshakeHandler implements HandshakeHandler { return null; } + /** + * Filter the list of WebSocket Extensions requested by the client. + * Since the negotiation process happens during the upgrade phase within the server + * implementation, one can customize the applied extensions only by filtering the + * requested extensions by the client. + * + *

The default implementation of this method doesn't filter any of the extensions + * requested by the client. + * @param requestedExtensions the list of extensions requested by the client + * @param supportedExtensions the list of extensions supported by the server + * @return the filtered list of requested extensions + */ + protected List filterRequestedExtensions(List requestedExtensions, + List supportedExtensions) { + + if (requestedExtensions != null) { + if (logger.isDebugEnabled()) { + logger.debug("Requested extension(s): " + requestedExtensions + + ", supported extension(s): " + supportedExtensions); + } + } + return requestedExtensions; + } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java index b781426445..89b197fe38 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java @@ -16,10 +16,12 @@ package org.springframework.web.socket.server; +import java.util.List; import java.util.Map; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; /** @@ -36,6 +38,12 @@ public interface RequestUpgradeStrategy { */ String[] getSupportedVersions(); + /** + * @return the list of available WebSocket protocol extensions, + * implemented by the underlying WebSocket server. + */ + List getAvailableExtensions(ServerHttpRequest request); + /** * Perform runtime specific steps to complete the upgrade. Invoked after successful * negotiation of the handshake request. diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java index 1d41f06bb1..221ac25634 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/config/WebSocketConfigurer.java @@ -16,8 +16,7 @@ package org.springframework.web.socket.server.config; -import org.eclipse.jetty.websocket.server.WebSocketHandler; - +import org.springframework.web.socket.WebSocketHandler; /** diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java index 6f0b783887..7349615f50 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractGlassFishRequestUpgradeStrategy.java @@ -19,15 +19,21 @@ package org.springframework.web.socket.server.support; import java.io.IOException; import java.lang.reflect.Constructor; import java.net.URI; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Random; +import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.DeploymentException; import javax.websocket.Endpoint; +import javax.websocket.Extension; +import javax.websocket.server.ServerContainer; +import org.apache.tomcat.websocket.server.WsServerContainer; import org.glassfish.tyrus.core.ComponentProviderService; import org.glassfish.tyrus.core.EndpointWrapper; import org.glassfish.tyrus.core.ErrorCollector; @@ -47,6 +53,7 @@ import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration; import org.springframework.web.socket.server.endpoint.ServletServerContainerFactoryBean; @@ -66,11 +73,26 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt private final static Random random = new Random(); + private List availableExtensions; + @Override public String[] getSupportedVersions() { return StringUtils.commaDelimitedListToStringArray(Version.getSupportedWireProtocolVersions()); } + @Override + public List getAvailableExtensions(ServerHttpRequest request) { + + if(this.availableExtensions == null) { + this.availableExtensions = new ArrayList(); + HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); + for(Extension extension : getContainer(servletRequest).getInstalledExtensions()) { + this.availableExtensions.add(parseStandardExtension(extension)); + } + } + return this.availableExtensions; + } + @Override public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, Endpoint endpoint) throws HandshakeFailureException { @@ -103,6 +125,13 @@ public abstract class AbstractGlassFishRequestUpgradeStrategy extends AbstractSt } } + public ServerContainer getContainer(HttpServletRequest servletRequest) { + + String attributeName = "javax.websocket.server.ServerContainer"; + ServletContext servletContext = servletRequest.getServletContext(); + return (ServerContainer)servletContext.getAttribute(attributeName); + } + private boolean performUpgrade(HttpServletRequest request, HttpServletResponse response, HttpHeaders headers, WebSocketApplication wsApp) throws IOException { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java index 925540c3fc..0f62add358 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java @@ -17,15 +17,18 @@ package org.springframework.web.socket.server.support; import java.net.InetSocketAddress; +import java.util.HashMap; import java.util.Map; import javax.websocket.Endpoint; +import javax.websocket.Extension; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.http.HttpHeaders; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.StandardWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.StandardWebSocketSession; @@ -62,4 +65,12 @@ public abstract class AbstractStandardUpgradeStrategy implements RequestUpgradeS protected abstract void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, Endpoint endpoint) throws HandshakeFailureException; + protected WebSocketExtension parseStandardExtension(Extension extension) { + Map params = new HashMap(extension.getParameters().size()); + for(Extension.Parameter param : extension.getParameters()) { + params.put(param.getName(),param.getValue()); + } + return new WebSocketExtension(extension.getName(),params); + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java index 14f6d1ce6b..3064adb417 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java @@ -17,6 +17,8 @@ package org.springframework.web.socket.server.support; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import javax.servlet.http.HttpServletRequest; @@ -34,6 +36,7 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.JettyWebSocketHandlerAdapter; import org.springframework.web.socket.adapter.JettyWebSocketSession; @@ -54,6 +57,8 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { private WebSocketServerFactory factory; + private List availableExtensions; + /** * Default constructor that creates {@link WebSocketServerFactory} through its default @@ -92,6 +97,17 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { return new String[] { String.valueOf(HandshakeRFC6455.VERSION) }; } + @Override + public List getAvailableExtensions(ServerHttpRequest request) { + if(this.availableExtensions == null) { + this.availableExtensions = new ArrayList(); + for(String extensionName : this.factory.getExtensionFactory().getExtensionNames()) { + this.availableExtensions.add(new WebSocketExtension(extensionName)); + } + } + return this.availableExtensions; + } + @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String protocol, WebSocketHandler wsHandler, Map attrs) throws HandshakeFailureException { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java index 4a96dff264..7a5608ae8b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/TomcatRequestUpgradeStrategy.java @@ -17,8 +17,10 @@ package org.springframework.web.socket.server.support; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Map; import javax.servlet.ServletContext; @@ -26,6 +28,7 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.Endpoint; +import javax.websocket.Extension; import org.apache.tomcat.websocket.server.WsServerContainer; import org.springframework.http.server.ServerHttpRequest; @@ -33,6 +36,7 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.util.Assert; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.server.HandshakeFailureException; import org.springframework.web.socket.server.endpoint.ServerEndpointRegistration; import org.springframework.web.socket.server.endpoint.ServletServerContainerFactoryBean; @@ -50,12 +54,26 @@ import org.springframework.web.socket.server.endpoint.ServletServerContainerFact */ public class TomcatRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy { + private List availableExtensions; @Override public String[] getSupportedVersions() { return new String[] { "13" }; } + @Override + public List getAvailableExtensions(ServerHttpRequest request) { + + if(this.availableExtensions == null) { + this.availableExtensions = new ArrayList(); + HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); + for(Extension extension : getContainer(servletRequest).getInstalledExtensions()) { + this.availableExtensions.add(parseStandardExtension(extension)); + } + } + return this.availableExtensions; + } + @Override public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response, String acceptedProtocol, Endpoint endpoint) throws HandshakeFailureException { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java index 5b76febfaf..474d94dcd2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.List; import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; @@ -30,6 +31,7 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.util.Assert; import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.SockJsException; import org.springframework.web.socket.sockjs.SockJsTransportFailureException; @@ -66,6 +68,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { private String acceptedProtocol; + private List extensions; public AbstractHttpSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler wsHandler, Map handshakeAttributes) { @@ -116,6 +119,9 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { this.remoteAddress = remoteAddress; } + @Override + public List getExtensions() { return this.extensions; } + /** * Unlike WebSocket where sub-protocol negotiation is part of the * initial handshake, in HTTP transports the same negotiation must @@ -152,6 +158,7 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { this.principal = request.getPrincipal(); this.localAddress = request.getLocalAddress(); this.remoteAddress = request.getRemoteAddress(); + this.extensions = WebSocketExtension.parseHeaders(response.getHeaders().getSecWebSocketExtensions()); try { delegateConnectionEstablished(); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java index 8474f108c0..e66af40e7e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.security.Principal; +import java.util.List; import java.util.Map; import org.springframework.http.HttpHeaders; @@ -27,6 +28,7 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.NativeWebSocketSession; @@ -89,6 +91,12 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession return this.wsSession.getAcceptedProtocol(); } + @Override + public List getExtensions() { + checkDelegateSessionInitialized(); + return this.wsSession.getExtensions(); + } + private void checkDelegateSessionInitialized() { Assert.state(this.wsSession != null, "WebSocketSession not yet initialized"); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java new file mode 100644 index 0000000000..2a674af173 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketExtensionTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +import org.hamcrest.Matchers; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +/** + * Test fixture for {@link WebSocketExtension} + * @author Brian Clozel + */ +public class WebSocketExtensionTests { + + @Test + public void parseHeaderSingle() { + List extensions = WebSocketExtension.parseHeader("x-test-extension ; foo=bar"); + assertThat(extensions, Matchers.hasSize(1)); + WebSocketExtension extension = extensions.get(0); + assertEquals("x-test-extension", extension.getName()); + assertEquals(1, extension.getParameters().size()); + assertEquals("bar", extension.getParameters().get("foo")); + } + + @Test + public void parseHeaderMultiple() { + List extensions = WebSocketExtension.parseHeader("x-foo-extension, x-bar-extension"); + assertThat(extensions, Matchers.hasSize(2)); + assertEquals("x-foo-extension", extensions.get(0).getName()); + assertEquals("x-bar-extension", extensions.get(1).getName()); + } + + @Test + public void parseHeaders() { + List extensions = new ArrayList(); + extensions.add("x-foo-extension, x-bar-extension"); + extensions.add("x-test-extension"); + List parsedExtensions = WebSocketExtension.parseHeaders(extensions); + assertThat(parsedExtensions, Matchers.hasSize(3)); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java index 1cee547aab..9f145cdd1e 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/TestSockJsSession.java @@ -26,6 +26,7 @@ import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.support.frame.SockJsFrame; @@ -58,6 +59,8 @@ public class TestSockJsSession extends AbstractSockJsSession { private String subProtocol; + private List extensions = new ArrayList(); + public TestSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler wsHandler, Map attributes) { @@ -118,7 +121,7 @@ public class TestSockJsSession extends AbstractSockJsSession { } /** - * @param remoteAddress the remoteAddress to set + * @param localAddress the remoteAddress to set */ public void setLocalAddress(InetSocketAddress localAddress) { this.localAddress = localAddress; @@ -148,6 +151,18 @@ public class TestSockJsSession extends AbstractSockJsSession { this.subProtocol = protocol; } + /** + * @return the extensions + */ + @Override + public List getExtensions() { return this.extensions; } + + /** + * + * @param extensions the extensions to set + */ + public void setExtensions(List extensions) { this.extensions = extensions; } + public CloseStatus getCloseStatus() { return this.closeStatus; } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java index 6f16888e68..17935ac62d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/support/TestWebSocketSession.java @@ -27,6 +27,7 @@ import java.util.Map; import org.springframework.http.HttpHeaders; import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; @@ -51,6 +52,8 @@ public class TestWebSocketSession implements WebSocketSession { private String protocol; + private List extensions = new ArrayList(); + private boolean open; private final List> messages = new ArrayList<>(); @@ -149,7 +152,7 @@ public class TestWebSocketSession implements WebSocketSession { } /** - * @param remoteAddress the remoteAddress to set + * @param localAddress the remoteAddress to set */ public void setLocalAddress(InetSocketAddress localAddress) { this.localAddress = localAddress; @@ -184,6 +187,18 @@ public class TestWebSocketSession implements WebSocketSession { this.protocol = protocol; } + /** + * @return the extensions + */ + @Override + public List getExtensions() { return this.extensions; } + + /** + * + * @param extensions the extensions to set + */ + public void setExtensions(List extensions) { this.extensions = extensions; } + /** * @return the open */