From c86f711de5dbd2a601a70390b24f926de7efb485 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Mon, 21 Apr 2014 11:18:54 -0400 Subject: [PATCH 1/7] Polish SockJsFrame --- .../web/socket/sockjs/frame/SockJsFrame.java | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java index fbd5ca13bb..97a620170d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java @@ -21,7 +21,8 @@ import java.nio.charset.Charset; import org.springframework.util.Assert; /** - * Represents a SockJS frame and provides factory methods for creating SockJS frames. + * Represents a SockJS frame. Provides factory methods to create SockJS frames on + * the server side. * * @author Rossen Stoyanchev * @since 4.0 @@ -30,21 +31,29 @@ public class SockJsFrame { private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); - private static final SockJsFrame openFrame = new SockJsFrame("o"); + private static final SockJsFrame OPEN_FRAME = new SockJsFrame("o"); - private static final SockJsFrame heartbeatFrame = new SockJsFrame("h"); + private static final SockJsFrame HEARTBEAT_FRAME = new SockJsFrame("h"); - private static final SockJsFrame closeGoAwayFrame = closeFrame(3000, "Go away!"); + private static final SockJsFrame CLOSE_GO_AWAY_FRAME = closeFrame(3000, "Go away!"); - private static final SockJsFrame closeAnotherConnectionOpenFrame = closeFrame(2010, "Another connection still open"); + private static final SockJsFrame CLOSE_ANOTHER_CONNECTION_OPEN_FRAME = closeFrame(2010, "Another connection still open"); + private final String content; + + + public SockJsFrame(String content) { + Assert.notNull("Content must not be null"); + this.content = content; + } + public static SockJsFrame openFrame() { - return openFrame; + return OPEN_FRAME; } public static SockJsFrame heartbeatFrame() { - return heartbeatFrame; + return HEARTBEAT_FRAME; } public static SockJsFrame messageFrame(SockJsMessageCodec codec, String... messages) { @@ -53,11 +62,11 @@ public class SockJsFrame { } public static SockJsFrame closeFrameGoAway() { - return closeGoAwayFrame; + return CLOSE_GO_AWAY_FRAME; } public static SockJsFrame closeFrameAnotherConnectionOpen() { - return closeAnotherConnectionOpenFrame; + return CLOSE_ANOTHER_CONNECTION_OPEN_FRAME; } public static SockJsFrame closeFrame(int code, String reason) { @@ -65,15 +74,6 @@ public class SockJsFrame { } - private final String content; - - - public SockJsFrame(String content) { - Assert.notNull("Content must not be null"); - this.content = content; - } - - public String getContent() { return this.content; } @@ -82,6 +82,7 @@ public class SockJsFrame { return this.content.getBytes(UTF8_CHARSET); } + @Override public boolean equals(Object other) { if (this == other) { From c14ba1a0ff72ed11eaa5fc7d808f6f3a3fbe2916 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Mon, 21 Apr 2014 13:57:09 -0400 Subject: [PATCH 2/7] Add SockJsFrameType enum SPR-10797 --- .../frame/DefaultSockJsFrameFormat.java | 13 ++- .../web/socket/sockjs/frame/SockJsFrame.java | 68 +++++++++++-- .../sockjs/frame/SockJsFrameFormat.java | 13 ++- .../socket/sockjs/frame/SockJsFrameType.java | 29 ++++++ .../AbstractHttpSendingTransportHandler.java | 6 +- .../session/AbstractHttpSockJsSession.java | 6 +- .../socket/sockjs/frame/SockJsFrameTests.java | 99 +++++++++++++++++++ .../HttpSendingTransportHandlerTests.java | 12 +-- 8 files changed, 219 insertions(+), 27 deletions(-) create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameType.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/sockjs/frame/SockJsFrameTests.java diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/DefaultSockJsFrameFormat.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/DefaultSockJsFrameFormat.java index 56220f5f46..fbb92f0944 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/DefaultSockJsFrameFormat.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/DefaultSockJsFrameFormat.java @@ -19,6 +19,10 @@ package org.springframework.web.socket.sockjs.frame; import org.springframework.util.Assert; /** + * A default implementation of + * {@link org.springframework.web.socket.sockjs.frame.SockJsFrameFormat} that relies + * on {@link java.lang.String#format(String, Object...)}.. + * * @author Rossen Stoyanchev * @since 4.0 */ @@ -33,14 +37,9 @@ public class DefaultSockJsFrameFormat implements SockJsFrameFormat { } - /** - * @param frame the SockJs frame. - * @return new SockJsFrame instance with the formatted content - */ @Override - public SockJsFrame format(SockJsFrame frame) { - String content = String.format(this.format, preProcessContent(frame.getContent())); - return new SockJsFrame(content); + public String format(SockJsFrame frame) { + return String.format(this.format, preProcessContent(frame.getContent())); } protected String preProcessContent(String content) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java index 97a620170d..7e07bf7ced 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java @@ -18,18 +18,17 @@ package org.springframework.web.socket.sockjs.frame; import java.nio.charset.Charset; -import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** - * Represents a SockJS frame. Provides factory methods to create SockJS frames on - * the server side. + * Represents a SockJS frame. Provides factory methods to create SockJS frames. * * @author Rossen Stoyanchev * @since 4.0 */ public class SockJsFrame { - private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + public static final Charset CHARSET = Charset.forName("UTF-8"); private static final SockJsFrame OPEN_FRAME = new SockJsFrame("o"); @@ -40,12 +39,40 @@ public class SockJsFrame { private static final SockJsFrame CLOSE_ANOTHER_CONNECTION_OPEN_FRAME = closeFrame(2010, "Another connection still open"); + private final SockJsFrameType type; + private final String content; + /** + * Create a new instance frame with the given frame content. + * @param content the content, must be a non-empty and represent a valid SockJS frame + */ public SockJsFrame(String content) { - Assert.notNull("Content must not be null"); - this.content = content; + StringUtils.hasText(content); + if ("o".equals(content)) { + this.type = SockJsFrameType.OPEN; + this.content = content; + } + else if ("h".equals(content)) { + this.type = SockJsFrameType.HEARTBEAT; + this.content = content; + } + else if (content.charAt(0) == 'a') { + this.type = SockJsFrameType.MESSAGE; + this.content = (content.length() > 1 ? content : "a[]"); + } + else if (content.charAt(0) == 'm') { + this.type = SockJsFrameType.MESSAGE; + this.content = (content.length() > 1 ? content : "null"); + } + else if (content.charAt(0) == 'c') { + this.type = SockJsFrameType.CLOSE; + this.content = (content.length() > 1 ? content : "c[]"); + } + else { + throw new IllegalArgumentException("Unexpected SockJS frame type in content=\"" + content + "\""); + } } public static SockJsFrame openFrame() { @@ -74,12 +101,39 @@ public class SockJsFrame { } + /** + * Return the SockJS frame type. + */ + public SockJsFrameType getType() { + return this.type; + } + + /** + * Return the SockJS frame content, never {@code null}. + */ public String getContent() { return this.content; } + /** + * Return the SockJS frame content as a byte array. + */ public byte[] getContentBytes() { - return this.content.getBytes(UTF8_CHARSET); + return this.content.getBytes(CHARSET); + } + + /** + * Return data contained in a SockJS "message" and "close" frames. Otherwise + * for SockJS "open" and "close" frames, which do not contain data, return + * {@code null}. + */ + public String getFrameData() { + if (SockJsFrameType.OPEN == getType() || SockJsFrameType.HEARTBEAT == getType()) { + return null; + } + else { + return getContent().substring(1); + } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameFormat.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameFormat.java index 858d2476ed..26376e468c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameFormat.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameFormat.java @@ -17,11 +17,22 @@ package org.springframework.web.socket.sockjs.frame; /** + * Applies a transport-specific format to the content of a SockJS frame resulting + * in a content that can be written out. Primarily for use in HTTP server-side + * transports that push data. + * + *

Formatting may vary from simply appending a new line character for XHR + * polling and streaming transports, to a jsonp-style callback function, + * surrounding script tags, and more. + * + *

For the various SockJS frame formats in use, see implementations of + * {@link org.springframework.web.socket.sockjs.transport.handler.AbstractHttpSendingTransportHandler#getFrameFormat(org.springframework.http.server.ServerHttpRequest) AbstractHttpSendingTransportHandler.getFrameFormat} + * * @author Rossen Stoyanchev * @since 4.0 */ public interface SockJsFrameFormat { - SockJsFrame format(SockJsFrame frame); + String format(SockJsFrame frame); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameType.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameType.java new file mode 100644 index 0000000000..eb6fc66460 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameType.java @@ -0,0 +1,29 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.frame; + +/** + * SockJS frame types. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public enum SockJsFrameType { + + OPEN, HEARTBEAT, MESSAGE, CLOSE + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java index 0bba697e10..0acd733767 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java @@ -82,12 +82,12 @@ public abstract class AbstractHttpSendingTransportHandler extends AbstractTransp } else { logger.debug("another " + getTransportType() + " connection still open: " + sockJsSession); - SockJsFrame frame = getFrameFormat(request).format(SockJsFrame.closeFrameAnotherConnectionOpen()); + String formattedFrame = getFrameFormat(request).format(SockJsFrame.closeFrameAnotherConnectionOpen()); try { - response.getBody().write(frame.getContentBytes()); + response.getBody().write(formattedFrame.getBytes(SockJsFrame.CHARSET)); } catch (IOException ex) { - throw new SockJsException("Failed to send " + frame, sockJsSession.getId(), ex); + throw new SockJsException("Failed to send " + formattedFrame, sockJsSession.getId(), ex); } } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java index f433ea5f27..bd3dccd613 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java @@ -351,11 +351,11 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { @Override protected void writeFrameInternal(SockJsFrame frame) throws IOException { if (isActive()) { - frame = this.frameFormat.format(frame); + String formattedFrame = this.frameFormat.format(frame); if (logger.isTraceEnabled()) { - logger.trace("Writing " + frame); + logger.trace("Writing " + formattedFrame); } - getResponse().getBody().write(frame.getContentBytes()); + getResponse().getBody().write(formattedFrame.getBytes(SockJsFrame.CHARSET)); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/frame/SockJsFrameTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/frame/SockJsFrameTests.java new file mode 100644 index 0000000000..d32b49d51d --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/frame/SockJsFrameTests.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.frame; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** + * Unit tests for {@link org.springframework.web.socket.sockjs.frame.SockJsFrame}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class SockJsFrameTests { + + + @Test + public void openFrame() { + SockJsFrame frame = SockJsFrame.openFrame(); + + assertEquals("o", frame.getContent()); + assertEquals(SockJsFrameType.OPEN, frame.getType()); + assertNull(frame.getFrameData()); + } + + @Test + public void heartbeatFrame() { + SockJsFrame frame = SockJsFrame.heartbeatFrame(); + + assertEquals("h", frame.getContent()); + assertEquals(SockJsFrameType.HEARTBEAT, frame.getType()); + assertNull(frame.getFrameData()); + } + + @Test + public void messageArrayFrame() { + SockJsFrame frame = SockJsFrame.messageFrame(new Jackson2SockJsMessageCodec(), "m1", "m2"); + + assertEquals("a[\"m1\",\"m2\"]", frame.getContent()); + assertEquals(SockJsFrameType.MESSAGE, frame.getType()); + assertEquals("[\"m1\",\"m2\"]", frame.getFrameData()); + } + + @Test + public void messageArrayFrameEmpty() { + SockJsFrame frame = new SockJsFrame("a"); + + assertEquals("a[]", frame.getContent()); + assertEquals(SockJsFrameType.MESSAGE, frame.getType()); + assertEquals("[]", frame.getFrameData()); + + frame = new SockJsFrame("a[]"); + + assertEquals("a[]", frame.getContent()); + assertEquals(SockJsFrameType.MESSAGE, frame.getType()); + assertEquals("[]", frame.getFrameData()); + } + + @Test + public void closeFrame() { + SockJsFrame frame = SockJsFrame.closeFrame(3000, "Go Away!"); + + assertEquals("c[3000,\"Go Away!\"]", frame.getContent()); + assertEquals(SockJsFrameType.CLOSE, frame.getType()); + assertEquals("[3000,\"Go Away!\"]", frame.getFrameData()); + } + + @Test + public void closeFrameEmpty() { + SockJsFrame frame = new SockJsFrame("c"); + + assertEquals("c[]", frame.getContent()); + assertEquals(SockJsFrameType.CLOSE, frame.getType()); + assertEquals("[]", frame.getFrameData()); + + frame = new SockJsFrame("c[]"); + + assertEquals("c[]", frame.getContent()); + assertEquals(SockJsFrameType.CLOSE, frame.getType()); + assertEquals("[]", frame.getFrameData()); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java index 3797f9024b..dcab42610e 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java @@ -168,24 +168,24 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests SockJsFrame frame = SockJsFrame.openFrame(); SockJsFrameFormat format = new XhrPollingTransportHandler().getFrameFormat(this.request); - SockJsFrame formatted = format.format(frame); - assertEquals(frame.getContent() + "\n", formatted.getContent()); + String formatted = format.format(frame); + assertEquals(frame.getContent() + "\n", formatted); format = new XhrStreamingTransportHandler().getFrameFormat(this.request); formatted = format.format(frame); - assertEquals(frame.getContent() + "\n", formatted.getContent()); + assertEquals(frame.getContent() + "\n", formatted); format = new HtmlFileTransportHandler().getFrameFormat(this.request); formatted = format.format(frame); - assertEquals("\r\n", formatted.getContent()); + assertEquals("\r\n", formatted); format = new EventSourceTransportHandler().getFrameFormat(this.request); formatted = format.format(frame); - assertEquals("data: " + frame.getContent() + "\r\n\r\n", formatted.getContent()); + assertEquals("data: " + frame.getContent() + "\r\n\r\n", formatted); format = new JsonpPollingTransportHandler().getFrameFormat(this.request); formatted = format.format(frame); - assertEquals("callback(\"" + frame.getContent() + "\");\r\n", formatted.getContent()); + assertEquals("callback(\"" + frame.getContent() + "\");\r\n", formatted); } } From dc1d85d0459b19137687d1b0449ac2c0006a079e Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Thu, 19 Jun 2014 11:11:41 -0400 Subject: [PATCH 3/7] Support registering Filters in WebSocket test servers --- .../web/socket/JettyWebSocketTestServer.java | 14 +++++++- .../web/socket/TomcatWebSocketTestServer.java | 26 +++++++++++++-- .../web/socket/UndertowTestServer.java | 33 ++++++++++++++++++- .../web/socket/WebSocketIntegrationTests.java | 20 ++++------- .../web/socket/WebSocketTestServer.java | 4 ++- .../WebSocketConfigurationTests.java | 14 ++++---- 6 files changed, 86 insertions(+), 25 deletions(-) diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java index 5d9af0aa31..9bea3bcd76 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java @@ -17,12 +17,17 @@ package org.springframework.web.socket; import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.FilterHolder; import org.eclipse.jetty.servlet.ServletContextHandler; import org.eclipse.jetty.servlet.ServletHolder; import org.springframework.util.SocketUtils; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.DispatcherServlet; +import javax.servlet.DispatcherType; +import javax.servlet.Filter; +import java.util.EnumSet; + /** * Jetty based {@link WebSocketTestServer}. * @@ -46,13 +51,20 @@ public class JettyWebSocketTestServer implements WebSocketTestServer { } @Override - public void deployConfig(WebApplicationContext cxt) { + public void deployConfig(WebApplicationContext cxt, Filter... filters) { ServletContextHandler contextHandler = new ServletContextHandler(); ServletHolder servletHolder = new ServletHolder(new DispatcherServlet(cxt)); contextHandler.addServlet(servletHolder, "/"); + for (Filter filter : filters) { + contextHandler.addFilter(new FilterHolder(filter), "/*", getDispatcherTypes()); + } this.jettyServer.setHandler(contextHandler); } + private EnumSet getDispatcherTypes() { + return EnumSet.of(DispatcherType.REQUEST, DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.ASYNC); + } + @Override public void undeployConfig() { // Stopping jetty will undeploy the servlet diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/TomcatWebSocketTestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/TomcatWebSocketTestServer.java index 2d1e855a84..caff98f687 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/TomcatWebSocketTestServer.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/TomcatWebSocketTestServer.java @@ -18,17 +18,23 @@ package org.springframework.web.socket; import java.io.File; import java.io.IOException; +import java.util.EnumSet; import org.apache.catalina.Context; import org.apache.catalina.connector.Connector; import org.apache.catalina.startup.Tomcat; import org.apache.coyote.http11.Http11NioProtocol; import org.apache.tomcat.util.descriptor.web.ApplicationListener; +import org.apache.tomcat.util.descriptor.web.FilterDef; +import org.apache.tomcat.util.descriptor.web.FilterMap; import org.apache.tomcat.websocket.server.WsContextListener; import org.springframework.util.SocketUtils; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.DispatcherServlet; +import javax.servlet.DispatcherType; +import javax.servlet.Filter; + /** * Tomcat based {@link WebSocketTestServer}. * @@ -82,11 +88,27 @@ public class TomcatWebSocketTestServer implements WebSocketTestServer { } @Override - public void deployConfig(WebApplicationContext wac) { + public void deployConfig(WebApplicationContext wac, Filter... filters) { this.context = this.tomcatServer.addContext("", System.getProperty("java.io.tmpdir")); this.context.addApplicationListener(WS_APPLICATION_LISTENER); - Tomcat.addServlet(context, "dispatcherServlet", new DispatcherServlet(wac)); + Tomcat.addServlet(this.context, "dispatcherServlet", new DispatcherServlet(wac)).setAsyncSupported(true); this.context.addServletMapping("/", "dispatcherServlet"); + for (Filter filter : filters) { + FilterDef filterDef = new FilterDef(); + filterDef.setFilterName(filter.getClass().getName()); + filterDef.setFilter(filter); + filterDef.setAsyncSupported("true"); + this.context.addFilterDef(filterDef); + FilterMap filterMap = new FilterMap(); + filterMap.setFilterName(filter.getClass().getName()); + filterMap.addURLPattern("/*"); + filterMap.setDispatcher("REQUEST,FORWARD,INCLUDE,ASYNC"); + this.context.addFilterMap(filterMap); + } + } + + private EnumSet getDispatcherTypes() { + return EnumSet.of(DispatcherType.REQUEST, DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.ASYNC); } @Override diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/UndertowTestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/UndertowTestServer.java index 0980cbc76a..7d0cdd77b1 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/UndertowTestServer.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/UndertowTestServer.java @@ -16,12 +16,15 @@ package org.springframework.web.socket; +import javax.servlet.DispatcherType; +import javax.servlet.Filter; import javax.servlet.Servlet; import javax.servlet.ServletException; import io.undertow.Undertow; import io.undertow.servlet.api.DeploymentInfo; import io.undertow.servlet.api.DeploymentManager; +import io.undertow.servlet.api.FilterInfo; import io.undertow.servlet.api.InstanceFactory; import io.undertow.servlet.api.InstanceHandle; import io.undertow.websockets.jsr.WebSocketDeploymentInfo; @@ -56,7 +59,7 @@ public class UndertowTestServer implements WebSocketTestServer { } @Override - public void deployConfig(WebApplicationContext cxt) { + public void deployConfig(WebApplicationContext cxt, Filter... filters) { DispatcherServletInstanceFactory servletFactory = new DispatcherServletInstanceFactory(cxt); DeploymentInfo servletBuilder = deployment() @@ -66,6 +69,13 @@ public class UndertowTestServer implements WebSocketTestServer { .addServlet(servlet("DispatcherServlet", DispatcherServlet.class, servletFactory).addMapping("/")) .addServletContextAttribute(WebSocketDeploymentInfo.ATTRIBUTE_NAME, new WebSocketDeploymentInfo()); + for (final Filter filter : filters) { + String filterName = filter.getClass().getName(); + servletBuilder.addFilter(new FilterInfo(filterName, filter.getClass(), new FilterInstanceFactory(filter))); + for (DispatcherType type : DispatcherType.values()) { + servletBuilder.addFilterUrlMapping(filterName, "/*", type); + } + } this.manager = defaultContainer().addDeployment(servletBuilder); this.manager.deploy(); @@ -117,4 +127,25 @@ public class UndertowTestServer implements WebSocketTestServer { } } + private static class FilterInstanceFactory implements InstanceFactory { + + private final Filter filter; + + private FilterInstanceFactory(Filter filter) { + this.filter = filter; + } + + @Override + public InstanceHandle createInstance() throws InstantiationException { + return new InstanceHandle() { + @Override + public Filter getInstance() { + return filter; + } + @Override + public void release() {} + }; + } + } + } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java index 8673abb394..04fe9651ec 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java @@ -58,24 +58,22 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTest @Override protected Class[] getAnnotatedConfigClasses() { - return new Class[] {TestWebSocketConfigurer.class}; + return new Class[] { TestConfig.class }; } @Test public void subProtocolNegotiation() throws Exception { WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); headers.setSecWebSocketProtocol("foo"); - - WebSocketSession session = this.webSocketClient.doHandshake( - new AbstractWebSocketHandler() {}, headers, new URI(getWsBaseUrl() + "/ws")).get(); - + URI url = new URI(getWsBaseUrl() + "/ws"); + WebSocketSession session = this.webSocketClient.doHandshake(new TextWebSocketHandler(), headers, url).get(); assertEquals("foo", session.getAcceptedProtocol()); } @Configuration @EnableWebSocket - static class TestWebSocketConfigurer implements WebSocketConfigurer { + static class TestConfig implements WebSocketConfigurer { @Autowired private DefaultHandshakeHandler handshakeHandler; // can't rely on classpath for server detection @@ -83,17 +81,13 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTest @Override public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { this.handshakeHandler.setSupportedProtocols("foo", "bar", "baz"); - registry.addHandler(serverHandler(), "/ws").setHandshakeHandler(this.handshakeHandler); + registry.addHandler(handler(), "/ws").setHandshakeHandler(this.handshakeHandler); } @Bean - public TestServerWebSocketHandler serverHandler() { - return new TestServerWebSocketHandler(); + public TextWebSocketHandler handler() { + return new TextWebSocketHandler(); } } - - private static class TestServerWebSocketHandler extends TextWebSocketHandler { - } - } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketTestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketTestServer.java index 0de0efa81c..aa841e61fe 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketTestServer.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketTestServer.java @@ -18,6 +18,8 @@ package org.springframework.web.socket; import org.springframework.web.context.WebApplicationContext; +import javax.servlet.Filter; + /** * Contract for a test server to use for WebSocket integration tests. * @@ -27,7 +29,7 @@ public interface WebSocketTestServer { int getPort(); - void deployConfig(WebApplicationContext cxt); + void deployConfig(WebApplicationContext cxt, Filter... filters); void undeployConfig(); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationTests.java index 3a2e50bb10..4559712b28 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationTests.java @@ -57,7 +57,7 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes @Override protected Class[] getAnnotatedConfigClasses() { - return new Class[] {TestWebSocketConfigurer.class}; + return new Class[] { TestConfig.class }; } @Test @@ -65,7 +65,7 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes WebSocketSession session = this.webSocketClient.doHandshake( new AbstractWebSocketHandler() {}, getWsBaseUrl() + "/ws").get(); - TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class); + TestHandler serverHandler = this.wac.getBean(TestHandler.class); assertTrue(serverHandler.connectLatch.await(2, TimeUnit.SECONDS)); session.close(); @@ -76,7 +76,7 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes WebSocketSession session = this.webSocketClient.doHandshake( new AbstractWebSocketHandler() {}, getWsBaseUrl() + "/sockjs/websocket").get(); - TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class); + TestHandler serverHandler = this.wac.getBean(TestHandler.class); assertTrue(serverHandler.connectLatch.await(2, TimeUnit.SECONDS)); session.close(); @@ -85,7 +85,7 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes @Configuration @EnableWebSocket - static class TestWebSocketConfigurer implements WebSocketConfigurer { + static class TestConfig implements WebSocketConfigurer { @Autowired private HandshakeHandler handshakeHandler; // can't rely on classpath for server detection @@ -99,12 +99,12 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes } @Bean - public TestWebSocketHandler serverHandler() { - return new TestWebSocketHandler(); + public TestHandler serverHandler() { + return new TestHandler(); } } - private static class TestWebSocketHandler extends AbstractWebSocketHandler { + private static class TestHandler extends AbstractWebSocketHandler { private CountDownLatch connectLatch = new CountDownLatch(1); From e82df99a22b7f511fc8e2c92a8a206e06d16bd62 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Mon, 28 Apr 2014 22:01:14 -0400 Subject: [PATCH 4/7] Add SockJS client This change adds a new implementation of WebSocketClient that can connect to a SockJS server using one of the SockJS transports "websocket", "xhr_streaming", or "xhr". From a client perspective there is no implementation difference between "xhr_streaming" and "xhr". Just keep receiving and when the response is complete, start over. Other SockJS transports are browser specific and therefore not relevant in Java ("eventsource", "htmlfile" or iframe based variations). The client loosely mimics the behavior of the JavaScript SockJS client. First it sends an info request to find the server capabilities, then it tries to connect with each configured transport, falling back, or forcing a timeout and then falling back, until one of the configured transports succeeds. The WebSocketTransport can be configured with any Spring Framework WebSocketClient implementation (currently JSR-356 or Jetty 9). The XhrTransport currently has a RestTemplate-based and a Jetty HttpClient-based implementations. To use those to simulate a large number of users be sure to configure Jetty's HttpClient executor and maxConnectionsPerDestination to high numbers. The same is true for whichever underlying HTTP library is used with the RestTemplate (e.g. maxConnPerRoute and maxConnTotal in Apache HttpComponents). Issue: SPR-10797 --- build.gradle | 1 + .../web/socket/CloseStatus.java | 2 +- .../client/AbstractWebSocketClient.java | 11 +- .../messaging/StompSubProtocolHandler.java | 10 +- .../client/AbstractClientSockJsSession.java | 338 +++++++++++++++ .../sockjs/client/AbstractXhrTransport.java | 163 ++++++++ .../client/DefaultTransportRequest.java | 238 +++++++++++ .../socket/sockjs/client/InfoReceiver.java | 24 ++ .../sockjs/client/JettyXhrTransport.java | 252 +++++++++++ .../client/RestTemplateXhrTransport.java | 265 ++++++++++++ .../socket/sockjs/client/SockJsClient.java | 259 ++++++++++++ .../socket/sockjs/client/SockJsUrlInfo.java | 115 +++++ .../web/socket/sockjs/client/Transport.java | 40 ++ .../sockjs/client/TransportRequest.java | 53 +++ .../client/WebSocketClientSockJsSession.java | 136 ++++++ .../sockjs/client/WebSocketTransport.java | 129 ++++++ .../sockjs/client/XhrClientSockJsSession.java | 111 +++++ .../socket/sockjs/client/XhrTransport.java | 40 ++ .../socket/sockjs/client/package-info.java | 22 + .../web/socket/sockjs/frame/SockJsFrame.java | 2 +- .../TransportHandlingSockJsService.java | 2 +- .../web/socket/JettyWebSocketTestServer.java | 1 + .../AbstractSockJsIntegrationTests.java | 394 ++++++++++++++++++ .../client/ClientSockJsSessionTests.java | 280 +++++++++++++ .../client/DefaultTransportRequestTests.java | 139 ++++++ .../client/JettySockJsIntegrationTests.java | 101 +++++ .../client/RestTemplateXhrTransportTests.java | 228 ++++++++++ .../sockjs/client/SockJsClientTests.java | 137 ++++++ .../sockjs/client/SockJsUrlInfoTests.java | 90 ++++ .../socket/sockjs/client/TestTransport.java | 106 +++++ .../sockjs/client/XhrTransportTests.java | 155 +++++++ .../src/test/resources/log4j.properties | 4 +- 32 files changed, 3835 insertions(+), 13 deletions(-) create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractClientSockJsSession.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfo.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/Transport.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketClientSockJsSession.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketTransport.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/package-info.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/JettySockJsIntegrationTests.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfoTests.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java diff --git a/build.gradle b/build.gradle index 6579189eef..14c58a4df1 100644 --- a/build.gradle +++ b/build.gradle @@ -670,6 +670,7 @@ project("spring-websocket") { exclude group: "javax.servlet", module: "javax.servlet" } optional("org.eclipse.jetty.websocket:websocket-client:${jettyVersion}") + optional("org.eclipse.jetty:jetty-client:${jettyVersion}") optional("io.undertow:undertow-core:1.0.15.Final") optional("io.undertow:undertow-servlet:1.0.15.Final") { exclude group: "org.jboss.spec.javax.servlet", module: "jboss-servlet-api_3.1_spec" diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java b/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java index 364a9597b7..9733e7f2a5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java @@ -217,7 +217,7 @@ public final class CloseStatus { @Override public String toString() { - return "CloseStatus [code=" + this.code + ", reason=" + this.reason + "]"; + return "CloseStatus[code=" + this.code + ", reason=" + this.reason + "]"; } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java index ea08db729a..dbec02cd99 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java @@ -73,10 +73,7 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { WebSocketHttpHeaders headers, URI uri) { Assert.notNull(webSocketHandler, "webSocketHandler must not be null"); - Assert.notNull(uri, "uri must not be null"); - - String scheme = uri.getScheme(); - Assert.isTrue(((scheme != null) && ("ws".equals(scheme) || "wss".equals(scheme))), "Invalid scheme: " + scheme); + assertUri(uri); if (logger.isDebugEnabled()) { logger.debug("Connecting to " + uri); @@ -101,6 +98,12 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { Collections.emptyMap()); } + protected void assertUri(URI uri) { + Assert.notNull(uri, "uri must not be null"); + String scheme = uri.getScheme(); + Assert.isTrue(scheme != null && ("ws".equals(scheme) || "wss".equals(scheme)), "Invalid scheme: " + scheme); + } + /** * Perform the actual handshake to establish a connection to the server. * diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index c82ebf8928..417955a47d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -194,7 +194,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } } catch (Throwable ex) { - logger.error("Failed to parse WebSocket message to STOMP frame(s)", ex); + logger.error("Failed to parse WebSocket message to STOMP." + + "Sending STOMP ERROR to client, sessionId=" + session.getId(), ex); sendErrorMessage(session, ex); return; } @@ -232,7 +233,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } } catch (Throwable ex) { - logger.error("Terminating STOMP session due to failure to send message", ex); + logger.error("Parsed STOMP message but could not send it to to message channel. " + + "Sending STOMP ERROR to client, sessionId=" + session.getId(), ex); sendErrorMessage(session, ex); } } @@ -248,7 +250,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } protected void sendErrorMessage(WebSocketSession session, Throwable error) { - StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.ERROR); headerAccessor.setMessage(error.getMessage()); byte[] bytes = this.stompEncoder.encode(headerAccessor.getMessageHeaders(), EMPTY_PAYLOAD); @@ -331,7 +332,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE throw ex; } catch (Throwable ex) { - sendErrorMessage(session, ex); + logger.error("Failed to send WebSocket message to client, sessionId=" + session.getId(), ex); + command = StompCommand.ERROR; } finally { if (StompCommand.ERROR.equals(command)) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractClientSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractClientSockJsSession.java new file mode 100644 index 0000000000..16217fa03c --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractClientSockJsSession.java @@ -0,0 +1,338 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketMessage; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; +import org.springframework.web.socket.sockjs.frame.SockJsFrameType; +import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; + +import java.io.IOException; +import java.net.URI; +import java.security.Principal; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Base class for SockJS client implementations of {@link WebSocketSession}. + * Provides processing of incoming SockJS message frames and delegates lifecycle + * events and messages to the (application) {@link WebSocketHandler}. + * Sub-classes implement actual send as well as disconnect logic. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public abstract class AbstractClientSockJsSession implements WebSocketSession { + + protected final Log logger = LogFactory.getLog(getClass()); + + + private final TransportRequest request; + + private final WebSocketHandler webSocketHandler; + + private final SettableListenableFuture connectFuture; + + + private final Map attributes = new ConcurrentHashMap(); + + private volatile State state = State.NEW; + + private volatile CloseStatus closeStatus; + + + protected AbstractClientSockJsSession(TransportRequest request, WebSocketHandler handler, + SettableListenableFuture connectFuture) { + + Assert.notNull(request, "'request' is required"); + Assert.notNull(handler, "'handler' is required"); + Assert.notNull(connectFuture, "'connectFuture' is required"); + this.request = request; + this.webSocketHandler = handler; + this.connectFuture = connectFuture; + } + + + @Override + public String getId() { + return this.request.getSockJsUrlInfo().getSessionId(); + } + + @Override + public URI getUri() { + return this.request.getSockJsUrlInfo().getSockJsUrl(); + } + + @Override + public HttpHeaders getHandshakeHeaders() { + return this.request.getHandshakeHeaders(); + } + + @Override + public Map getAttributes() { + return this.attributes; + } + + @Override + public Principal getPrincipal() { + return this.request.getUser(); + } + + public SockJsMessageCodec getMessageCodec() { + return this.request.getMessageCodec(); + } + + public WebSocketHandler getWebSocketHandler() { + return this.webSocketHandler; + } + + /** + * Return a timeout cleanup task to invoke if the SockJS sessions is not + * fully established within the retransmission timeout period calculated in + * {@code SockJsRequest} based on the duration of the initial SockJS "Info" + * request. + */ + Runnable getTimeoutTask() { + return new Runnable() { + @Override + public void run() { + closeInternal(new CloseStatus(2007, "Transport timed out")); + } + }; + } + + @Override + public boolean isOpen() { + return State.OPEN.equals(this.state); + } + + public boolean isDisconnected() { + return (State.CLOSING.equals(this.state) || State.CLOSED.equals(this.state)); + } + + @Override + public final void sendMessage(WebSocketMessage message) throws IOException { + Assert.state(State.OPEN.equals(this.state), this + " is not open, current state=" + this.state); + Assert.isInstanceOf(TextMessage.class, message, this + " supports text messages only."); + String payload = ((TextMessage) message).getPayload(); + payload = getMessageCodec().encode(new String[] { payload }); + payload = payload.substring(1); // the client-side doesn't need message framing (letter "a") + message = new TextMessage(payload); + if (logger.isTraceEnabled()) { + logger.trace("Sending message " + message + " in " + this); + } + sendInternal((TextMessage) message); + } + + protected abstract void sendInternal(TextMessage textMessage) throws IOException; + + @Override + public final void close() throws IOException { + close(CloseStatus.NORMAL); + } + + @Override + public final void close(CloseStatus status) { + Assert.isTrue(status != null && isUserSetStatus(status), "Invalid close status: " + status); + if (logger.isInfoEnabled()) { + logger.info("Closing session with " + status + " in " + this); + } + closeInternal(status); + } + + private boolean isUserSetStatus(CloseStatus status) { + return (status.getCode() == 1000 || (status.getCode() >= 3000 && status.getCode() <= 4999)); + } + + protected void closeInternal(CloseStatus status) { + if (this.state == null) { + logger.warn("Ignoring close since connect() was never invoked"); + return; + } + if (State.CLOSING.equals(this.state) || State.CLOSED.equals(this.state)) { + logger.debug("Ignoring close (already closing or closed), current state=" + this.state); + return; + } + this.state = State.CLOSING; + this.closeStatus = status; + try { + disconnect(status); + } + catch (Throwable ex) { + if (logger.isErrorEnabled()) { + logger.error("Failed to close " + this, ex); + } + } + } + + protected abstract void disconnect(CloseStatus status) throws IOException; + + public void handleFrame(String payload) { + SockJsFrame frame = new SockJsFrame(payload); + if (SockJsFrameType.OPEN.equals(frame.getType())) { + handleOpenFrame(); + } + else if (SockJsFrameType.MESSAGE.equals(frame.getType())) { + handleMessageFrame(frame); + } + else if (SockJsFrameType.CLOSE.equals(frame.getType())) { + handleCloseFrame(frame); + } + else if (SockJsFrameType.HEARTBEAT.equals(frame.getType())) { + if (logger.isTraceEnabled()) { + logger.trace("Received heartbeat in " + this); + } + } + else { + // should never happen + throw new IllegalStateException("Unknown SockJS frame type " + frame + " in " + this); + } + } + + private void handleOpenFrame() { + if (logger.isInfoEnabled()) { + logger.info("Processing SockJS open frame in " + this); + } + if (State.NEW.equals(state)) { + this.state = State.OPEN; + try { + this.webSocketHandler.afterConnectionEstablished(this); + this.connectFuture.set(this); + } + catch (Throwable ex) { + if (logger.isErrorEnabled()) { + Class type = this.webSocketHandler.getClass(); + logger.error(type + ".afterConnectionEstablished threw exception in " + this, ex); + } + } + } + else { + if (logger.isDebugEnabled()) { + logger.debug("Open frame received in " + getId() + " but we're not" + + "connecting (current state=" + this.state + "). The server might " + + "have been restarted and lost track of the session."); + } + closeInternal(new CloseStatus(1006, "Server lost session")); + } + } + + private void handleMessageFrame(SockJsFrame frame) { + if (!isOpen()) { + if (logger.isWarnEnabled()) { + logger.warn("Ignoring received message due to state=" + this.state + " in " + this); + } + return; + } + String[] messages; + try { + messages = getMessageCodec().decode(frame.getFrameData()); + } + catch (IOException ex) { + if (logger.isErrorEnabled()) { + logger.error("Failed to decode data for SockJS \"message\" frame: " + frame + " in " + this, ex); + } + closeInternal(CloseStatus.BAD_DATA); + return; + } + if (logger.isTraceEnabled()) { + logger.trace("Processing SockJS message frame " + frame.getContent() + " in " + this); + } + for (String message : messages) { + try { + if (isOpen()) { + this.webSocketHandler.handleMessage(this, new TextMessage(message)); + } + } + catch (Throwable ex) { + Class type = this.webSocketHandler.getClass(); + logger.error(type + ".handleMessage threw an exception on " + frame + " in " + this, ex); + } + } + } + + private void handleCloseFrame(SockJsFrame frame) { + CloseStatus closeStatus = CloseStatus.NO_STATUS_CODE; + try { + String[] data = getMessageCodec().decode(frame.getFrameData()); + if (data.length == 2) { + closeStatus = new CloseStatus(Integer.valueOf(data[0]), data[1]); + } + if (logger.isInfoEnabled()) { + logger.info("Processing SockJS close frame with " + closeStatus + " in " + this); + } + } + catch (IOException ex) { + if (logger.isErrorEnabled()) { + logger.error("Failed to decode data for " + frame + " in " + this, ex); + } + } + closeInternal(closeStatus); + } + + public void handleTransportError(Throwable error) { + try { + if (logger.isErrorEnabled()) { + logger.error("Transport error in " + this, error); + } + this.webSocketHandler.handleTransportError(this, error); + } + catch (Exception ex) { + Class type = this.webSocketHandler.getClass(); + if (logger.isErrorEnabled()) { + logger.error(type + ".handleTransportError threw an exception", ex); + } + } + } + + public void afterTransportClosed(CloseStatus closeStatus) { + this.closeStatus = (this.closeStatus != null ? this.closeStatus : closeStatus); + Assert.state(this.closeStatus != null, "CloseStatus not available"); + + if (logger.isInfoEnabled()) { + logger.info("Transport closed with " + this.closeStatus + " in " + this); + } + + this.state = State.CLOSED; + try { + this.webSocketHandler.afterConnectionClosed(this, this.closeStatus); + } + catch (Exception ex) { + if (logger.isErrorEnabled()) { + Class type = this.webSocketHandler.getClass(); + logger.error(type + ".afterConnectionClosed threw an exception", ex); + } + } + } + + @Override + public String toString() { + return getClass().getSimpleName() + "[id='" + getId() + ", url=" + getUri() + "]"; + } + + + private enum State { NEW, OPEN, CLOSING, CLOSED } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java new file mode 100644 index 0000000000..9ef2218528 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; + +import java.net.URI; + +/** + * Abstract base class for XHR transport implementations to extend. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public abstract class AbstractXhrTransport implements XhrTransport { + + protected static final String PRELUDE; + + static { + byte[] bytes = new byte[2048]; + for (int i = 0; i < bytes.length; i++) { + bytes[i] = 'h'; + } + PRELUDE = new String(bytes, SockJsFrame.CHARSET); + } + + protected Log logger = LogFactory.getLog(getClass()); + + private boolean xhrStreamingDisabled; + + private HttpHeaders requestHeaders = new HttpHeaders(); + + private HttpHeaders xhrSendRequestHeaders = new HttpHeaders(); + + + /** + * Whether to attempt to connect with "xhr_streaming" first before trying + * with "xhr" next, see {@link XhrTransport#isXhrStreamingDisabled()}. + * + *

By default this property is set to {@code false} which means both + * "xhr_streaming" and "xhr" will be tried. + */ + public void setXhrStreamingDisabled(boolean disabled) { + this.xhrStreamingDisabled = disabled; + } + + public boolean isXhrStreamingDisabled() { + return this.xhrStreamingDisabled; + } + + /** + * Configure headers to be added to every executed HTTP request. + * @param requestHeaders the headers to add to requests + */ + public void setRequestHeaders(HttpHeaders requestHeaders) { + this.requestHeaders.clear(); + this.xhrSendRequestHeaders.clear(); + if (requestHeaders != null) { + this.requestHeaders.putAll(requestHeaders); + this.xhrSendRequestHeaders.putAll(requestHeaders); + this.xhrSendRequestHeaders.setContentType(MediaType.APPLICATION_JSON); + } + } + + public HttpHeaders getRequestHeaders() { + return this.requestHeaders; + } + + @Override + public String executeInfoRequest(URI infoUrl) { + if (logger.isDebugEnabled()) { + logger.debug("Executing SockJS Info request, url=" + infoUrl); + } + ResponseEntity response = executeInfoRequestInternal(infoUrl); + if (response.getStatusCode() != HttpStatus.OK) { + if (logger.isErrorEnabled()) { + logger.error("SockJS Info request (url=" + infoUrl + ") failed: " + response); + } + throw new HttpServerErrorException(response.getStatusCode()); + } + if (logger.isDebugEnabled()) { + logger.debug("SockJS Info request (url=" + infoUrl + ") response: " + response); + } + return response.getBody(); + } + + protected abstract ResponseEntity executeInfoRequestInternal(URI infoUrl); + + @Override + public void executeSendRequest(URI url, TextMessage message) { + if (logger.isDebugEnabled()) { + logger.debug("Starting XHR send, url=" + url); + } + ResponseEntity response = executeSendRequestInternal(url, this.xhrSendRequestHeaders, message); + if (response.getStatusCode() != HttpStatus.NO_CONTENT) { + if (logger.isErrorEnabled()) { + logger.error("XHR send request (url=" + url + ") failed: " + response); + } + throw new HttpServerErrorException(response.getStatusCode()); + } + if (logger.isDebugEnabled()) { + logger.debug("XHR send request (url=" + url + ") response: " + response); + } + } + + protected abstract ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message); + + @Override + public ListenableFuture connect(TransportRequest request, WebSocketHandler handler) { + SettableListenableFuture connectFuture = new SettableListenableFuture(); + XhrClientSockJsSession session = new XhrClientSockJsSession(request, handler, this, connectFuture); + request.addTimeoutTask(session.getTimeoutTask()); + + URI receiveUrl = request.getTransportUrl(); + if (logger.isDebugEnabled()) { + logger.debug("Opening XHR session, receive url=" + receiveUrl); + } + + HttpHeaders handshakeHeaders = new HttpHeaders(); + handshakeHeaders.putAll(request.getHandshakeHeaders()); + handshakeHeaders.putAll(getRequestHeaders()); + + connectInternal(request, handler, receiveUrl, handshakeHeaders, session, connectFuture); + return connectFuture; + } + + protected abstract void connectInternal(TransportRequest request, WebSocketHandler handler, + URI receiveUrl, HttpHeaders handshakeHeaders, XhrClientSockJsSession session, + SettableListenableFuture connectFuture); + + + @Override + public String toString() { + return getClass().getSimpleName(); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java new file mode 100644 index 0000000000..fd1e630f77 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java @@ -0,0 +1,238 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.SockJsTransportFailureException; +import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.URI; +import java.security.Principal; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * A default implementation of + * {@link org.springframework.web.socket.sockjs.client.TransportRequest + * TransportRequest}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +class DefaultTransportRequest implements TransportRequest { + + private static Log logger = LogFactory.getLog(DefaultTransportRequest.class); + + + private final SockJsUrlInfo sockJsUrlInfo; + + private final HttpHeaders handshakeHeaders; + + private final Transport transport; + + private final TransportType serverTransportType; + + private SockJsMessageCodec codec; + + private Principal user; + + private long timeoutValue; + + private TaskScheduler timeoutScheduler; + + private final List timeoutTasks = new ArrayList(); + + private DefaultTransportRequest fallbackRequest; + + + public DefaultTransportRequest(SockJsUrlInfo sockJsUrlInfo, HttpHeaders handshakeHeaders, + Transport transport, TransportType serverTransportType, SockJsMessageCodec codec) { + + Assert.notNull(sockJsUrlInfo, "'sockJsUrlInfo' is required"); + Assert.notNull(transport, "'transport' is required"); + Assert.notNull(serverTransportType, "'transportType' is required"); + Assert.notNull(codec, "'codec' is required"); + this.sockJsUrlInfo = sockJsUrlInfo; + this.handshakeHeaders = (handshakeHeaders != null ? handshakeHeaders : new HttpHeaders()); + this.transport = transport; + this.serverTransportType = serverTransportType; + this.codec = codec; + } + + + @Override + public SockJsUrlInfo getSockJsUrlInfo() { + return this.sockJsUrlInfo; + } + + @Override + public HttpHeaders getHandshakeHeaders() { + return this.handshakeHeaders; + } + + @Override + public URI getTransportUrl() { + return this.sockJsUrlInfo.getTransportUrl(this.serverTransportType); + } + + public void setUser(Principal user) { + this.user = user; + } + + @Override + public Principal getUser() { + return this.user; + } + + @Override + public SockJsMessageCodec getMessageCodec() { + return this.codec; + } + + public void setTimeoutValue(long timeoutValue) { + this.timeoutValue = timeoutValue; + } + + public void setTimeoutScheduler(TaskScheduler scheduler) { + this.timeoutScheduler = scheduler; + } + + @Override + public void addTimeoutTask(Runnable runnable) { + this.timeoutTasks.add(runnable); + } + + public void setFallbackRequest(DefaultTransportRequest fallbackRequest) { + this.fallbackRequest = fallbackRequest; + } + + + public void connect(WebSocketHandler handler, SettableListenableFuture future) { + if (logger.isDebugEnabled()) { + logger.debug("Starting " + this); + } + ConnectCallback connectCallback = new ConnectCallback(handler, future); + scheduleConnectTimeoutTask(connectCallback); + this.transport.connect(this, handler).addCallback(connectCallback); + } + + + private void scheduleConnectTimeoutTask(ConnectCallback connectHandler) { + if (this.timeoutScheduler != null) { + if (logger.isDebugEnabled()) { + logger.debug("Scheduling connect to time out after " + this.timeoutValue + " milliseconds"); + } + Date timeoutDate = new Date(System.currentTimeMillis() + this.timeoutValue); + this.timeoutScheduler.schedule(connectHandler, timeoutDate); + } + else if (logger.isDebugEnabled()) { + logger.debug("Connect timeout task not scheduled. Is SockJsClient configured with a TaskScheduler?"); + } + } + + + @Override + public String toString() { + return "TransportRequest[url=" + getTransportUrl() + "]"; + } + + + /** + * Updates the given (global) future based success or failure to connect for + * the entire SockJS request regardless of which transport actually managed + * to connect. Also implements {@code Runnable} to handle a scheduled timeout + * callback. + */ + private class ConnectCallback implements ListenableFutureCallback, Runnable { + + private final WebSocketHandler handler; + + private final SettableListenableFuture future; + + private final AtomicBoolean handled = new AtomicBoolean(false); + + + public ConnectCallback(WebSocketHandler handler, SettableListenableFuture future) { + this.handler = handler; + this.future = future; + } + + + @Override + public void onSuccess(WebSocketSession session) { + if (this.handled.compareAndSet(false, true)) { + this.future.set(session); + } + else { + logger.error("Connect success/failure already handled for " + DefaultTransportRequest.this); + } + } + + @Override + public void onFailure(Throwable failure) { + handleFailure(failure, false); + } + + @Override + public void run() { + handleFailure(null, true); + } + + private void handleFailure(Throwable failure, boolean isTimeoutFailure) { + if (this.handled.compareAndSet(false, true)) { + if (isTimeoutFailure) { + String message = "Connect timed out for " + DefaultTransportRequest.this; + logger.error(message); + failure = new SockJsTransportFailureException(message, getSockJsUrlInfo().getSessionId(), null); + } + if (fallbackRequest != null) { + logger.error(DefaultTransportRequest.this + " failed. Falling back on next transport.", failure); + fallbackRequest.connect(this.handler, this.future); + } + else { + logger.error("No more fallback transports after " + DefaultTransportRequest.this, failure); + this.future.setException(failure); + } + if (isTimeoutFailure) { + try { + for (Runnable runnable : timeoutTasks) { + runnable.run(); + } + } + catch (Throwable ex) { + logger.error("Transport failed to run timeout tasks for " + DefaultTransportRequest.this, ex); + } + } + } + else { + logger.error("Connect success/failure events already took place for " + + DefaultTransportRequest.this + ". Ignoring this additional failure event.", failure); + } + } + } +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java new file mode 100644 index 0000000000..ae8ba7bf01 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java @@ -0,0 +1,24 @@ +package org.springframework.web.socket.sockjs.client; + +import java.net.URI; + +/** + * A simple contract for executing the SockJS "Info" request before the SockJS + * session starts. The request is used to check server capabilities such as + * whether it permits use of the WebSocket transport. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public interface InfoReceiver { + + /** + * Perform an HTTP request to the SockJS "Info" URL. + * and return the resulting JSON response content, or raise an exception. + * + * @param infoUrl the URL to obtain SockJS server information from + * @return the body of the response + */ + String executeInfoRequest(URI infoUrl); + +} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java new file mode 100644 index 0000000000..a4f8e23870 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java @@ -0,0 +1,252 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.client.api.ContentResponse; +import org.eclipse.jetty.client.api.Request; +import org.eclipse.jetty.client.api.Response; +import org.eclipse.jetty.client.util.StringContentProvider; +import org.eclipse.jetty.http.HttpFields; +import org.eclipse.jetty.http.HttpMethod; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.SockJsException; +import org.springframework.web.socket.sockjs.SockJsTransportFailureException; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; + +import java.io.ByteArrayOutputStream; +import java.net.URI; +import java.nio.ByteBuffer; +import java.util.Enumeration; + + +/** + * An XHR transport based on Jetty's {@link org.eclipse.jetty.client.HttpClient}. + * + *

When used for testing purposes (e.g. load testing) the {@code HttpClient} + * properties must be set to allow a larger than usual number of connections and + * threads. For example: + * + *

+ * HttpClient httpClient = new HttpClient();
+ * httpClient.setMaxConnectionsPerDestination(1000);
+ * httpClient.setExecutor(new QueuedThreadPool(500));
+ * 
+ * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class JettyXhrTransport extends AbstractXhrTransport implements XhrTransport { + + private final HttpClient httpClient; + + + public JettyXhrTransport(HttpClient httpClient) { + Assert.notNull(httpClient, "'httpClient' is required"); + this.httpClient = httpClient; + } + + + public HttpClient getHttpClient() { + return this.httpClient; + } + + @Override + protected ResponseEntity executeInfoRequestInternal(URI infoUrl) { + return executeRequest(infoUrl, HttpMethod.GET, getRequestHeaders(), null); + } + + @Override + public ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) { + return executeRequest(url, HttpMethod.POST, headers, message.getPayload()); + } + + protected ResponseEntity executeRequest(URI url, HttpMethod method, HttpHeaders headers, String body) { + Request httpRequest = this.httpClient.newRequest(url).method(method); + addHttpHeaders(httpRequest, headers); + if (body != null) { + httpRequest.content(new StringContentProvider(body)); + } + ContentResponse response; + try { + response = httpRequest.send(); + } + catch (Exception ex) { + throw new SockJsTransportFailureException("Failed to execute request to " + url, null, ex); + } + HttpStatus status = HttpStatus.valueOf(response.getStatus()); + HttpHeaders responseHeaders = toHttpHeaders(response.getHeaders()); + return (response.getContent() != null ? + new ResponseEntity(response.getContentAsString(), responseHeaders, status) : + new ResponseEntity(responseHeaders, status)); + } + + private static void addHttpHeaders(Request request, HttpHeaders headers) { + for (String name : headers.keySet()) { + for (String value : headers.get(name)) { + request.header(name, value); + } + } + } + + private static HttpHeaders toHttpHeaders(HttpFields httpFields) { + HttpHeaders responseHeaders = new HttpHeaders(); + Enumeration names = httpFields.getFieldNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + Enumeration values = httpFields.getValues(name); + while (values.hasMoreElements()) { + String value = values.nextElement(); + responseHeaders.add(name, value); + } + } + return responseHeaders; + } + + @Override + protected void connectInternal(TransportRequest request, WebSocketHandler handler, + URI url, HttpHeaders handshakeHeaders, XhrClientSockJsSession session, + SettableListenableFuture connectFuture) { + + SockJsResponseListener listener = new SockJsResponseListener(url, getRequestHeaders(), session, connectFuture); + executeReceiveRequest(url, handshakeHeaders, listener); + } + + private void executeReceiveRequest(URI url, HttpHeaders headers, SockJsResponseListener listener) { + if (logger.isDebugEnabled()) { + logger.debug("Starting XHR receive request, url=" + url); + } + Request httpRequest = this.httpClient.newRequest(url).method(HttpMethod.POST); + addHttpHeaders(httpRequest, headers); + httpRequest.send(listener); + } + + + /** + * Splits the body of an HTTP response into SockJS frames and delegates those + * to an {@link XhrClientSockJsSession}. + */ + private class SockJsResponseListener extends Response.Listener.Adapter { + + private final URI transportUrl; + + private final HttpHeaders receiveHeaders; + + private final XhrClientSockJsSession sockJsSession; + + private final SettableListenableFuture connectFuture; + + private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + + + public SockJsResponseListener(URI url, HttpHeaders headers, XhrClientSockJsSession sockJsSession, + SettableListenableFuture connectFuture) { + + this.transportUrl = url; + this.receiveHeaders = headers; + this.connectFuture = connectFuture; + this.sockJsSession = sockJsSession; + } + + + @Override + public void onBegin(Response response) { + if (response.getStatus() != 200) { + HttpStatus status = HttpStatus.valueOf(response.getStatus()); + response.abort(new HttpServerErrorException(status, "Unexpected XHR receive status")); + } + } + + @Override + public void onHeaders(Response response) { + if (logger.isDebugEnabled()) { + // Convert to HttpHeaders to avoid "\n" + logger.debug("XHR receive headers: " + toHttpHeaders(response.getHeaders())); + } + } + + @Override + public void onContent(Response response, ByteBuffer buffer) { + while (true) { + if (this.sockJsSession.isDisconnected()) { + if (logger.isDebugEnabled()) { + logger.debug("SockJS sockJsSession closed. Closing ClientHttpResponse."); + } + response.abort(new SockJsException("Session closed.", this.sockJsSession.getId(), null)); + return; + } + if (buffer.remaining() == 0) { + break; + } + int b = buffer.get(); + if (b == '\n') { + handleFrame(); + } + else { + this.outputStream.write(b); + } + } + } + + private void handleFrame() { + byte[] bytes = this.outputStream.toByteArray(); + this.outputStream.reset(); + String content = new String(bytes, SockJsFrame.CHARSET); + if (logger.isTraceEnabled()) { + logger.trace("XHR content received: " + content); + } + if (!PRELUDE.equals(content)) { + this.sockJsSession.handleFrame(new String(bytes, SockJsFrame.CHARSET)); + } + } + + @Override + public void onSuccess(Response response) { + if (this.outputStream.size() > 0) { + handleFrame(); + } + if (logger.isDebugEnabled()) { + logger.debug("XHR receive request completed."); + } + executeReceiveRequest(this.transportUrl, this.receiveHeaders, this); + } + + @Override + public void onFailure(Response response, Throwable failure) { + if (connectFuture.setException(failure)) { + return; + } + if (this.sockJsSession.isDisconnected()) { + this.sockJsSession.afterTransportClosed(null); + } + else { + this.sockJsSession.handleTransportError(failure); + this.sockJsSession.afterTransportClosed(new CloseStatus(1006, failure.getMessage())); + } + } + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java new file mode 100644 index 0000000000..1f25be4009 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java @@ -0,0 +1,265 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.core.task.TaskExecutor; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.util.Assert; +import org.springframework.util.StreamUtils; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.client.RequestCallback; +import org.springframework.web.client.ResponseExtractor; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; + +/** + * An {@code XhrTransport} implementation that uses a + * {@link org.springframework.web.client.RestTemplate RestTemplate}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class RestTemplateXhrTransport extends AbstractXhrTransport implements XhrTransport { + + private final RestOperations restTemplate; + + private TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(); + + + public RestTemplateXhrTransport() { + this(new RestTemplate()); + } + + public RestTemplateXhrTransport(RestOperations restTemplate) { + Assert.notNull(restTemplate, "'restTemplate' is required"); + this.restTemplate = restTemplate; + } + + + /** + * Return the configured {@code RestTemplate}. + */ + public RestOperations getRestTemplate() { + return this.restTemplate; + } + + /** + * Configure the {@code TaskExecutor} to use to execute XHR receive requests. + * + *

By default {@link org.springframework.core.task.SimpleAsyncTaskExecutor + * SimpleAsyncTaskExecutor} is configured which creates a new thread every + * time the transports connects. + * + * @param taskExecutor the task executor, cannot be {@code null} + */ + public void setTaskExecutor(TaskExecutor taskExecutor) { + Assert.notNull(this.taskExecutor); + this.taskExecutor = taskExecutor; + } + + /** + * Return the configured {@code TaskExecutor}. + */ + public TaskExecutor getTaskExecutor() { + return this.taskExecutor; + } + + + @Override + public ResponseEntity executeInfoRequestInternal(URI infoUrl) { + RequestCallback requestCallback = new XhrRequestCallback(getRequestHeaders()); + return this.restTemplate.execute(infoUrl, HttpMethod.GET, requestCallback, textExtractor); + } + + @Override + public ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) { + RequestCallback requestCallback = new XhrRequestCallback(headers, message.getPayload()); + return this.restTemplate.execute(url, HttpMethod.POST, requestCallback, textExtractor); + } + + @Override + protected void connectInternal(final TransportRequest request, final WebSocketHandler handler, + final URI receiveUrl, final HttpHeaders handshakeHeaders, final XhrClientSockJsSession session, + final SettableListenableFuture connectFuture) { + + getTaskExecutor().execute(new Runnable() { + @Override + public void run() { + XhrRequestCallback requestCallback = new XhrRequestCallback(handshakeHeaders); + XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(getRequestHeaders()); + XhrReceiveExtractor responseExtractor = new XhrReceiveExtractor(session); + while (true) { + if (session.isDisconnected()) { + session.afterTransportClosed(null); + break; + } + try { + if (logger.isDebugEnabled()) { + logger.debug("Starting XHR receive request, url=" + receiveUrl); + } + getRestTemplate().execute(receiveUrl, HttpMethod.POST, requestCallback, responseExtractor); + requestCallback = requestCallbackAfterHandshake; + } + catch (Throwable ex) { + if (!connectFuture.isDone()) { + connectFuture.setException(ex); + } + else { + session.handleTransportError(ex); + session.afterTransportClosed(new CloseStatus(1006, ex.getMessage())); + } + break; + } + } + } + }); + } + + + /** + * A RequestCallback to add the headers and (optionally) String content. + */ + private static class XhrRequestCallback implements RequestCallback { + + private final HttpHeaders headers; + + private final String body; + + + public XhrRequestCallback(HttpHeaders headers) { + this(headers, null); + } + + public XhrRequestCallback(HttpHeaders headers, String body) { + this.headers = headers; + this.body = body; + } + + + @Override + public void doWithRequest(ClientHttpRequest request) throws IOException { + if (this.headers != null) { + request.getHeaders().putAll(this.headers); + } + if (this.body != null) { + StreamUtils.copy(this.body, SockJsFrame.CHARSET, request.getBody()); + } + } + } + + /** + * A simple ResponseExtractor that reads the body into a String. + */ + private final static ResponseExtractor> textExtractor = + new ResponseExtractor>() { + + @Override + public ResponseEntity extractData(ClientHttpResponse response) throws IOException { + if (response.getBody() == null) { + return new ResponseEntity(response.getHeaders(), response.getStatusCode()); + } + else { + String body = StreamUtils.copyToString(response.getBody(), SockJsFrame.CHARSET); + return new ResponseEntity(body, response.getHeaders(), response.getStatusCode()); + } + } + }; + + /** + * Splits the body of an HTTP response into SockJS frames and delegates those + * to an {@link XhrClientSockJsSession}. + */ + private class XhrReceiveExtractor implements ResponseExtractor { + + private final XhrClientSockJsSession sockJsSession; + + + public XhrReceiveExtractor(XhrClientSockJsSession sockJsSession) { + this.sockJsSession = sockJsSession; + } + + + @Override + public Object extractData(ClientHttpResponse response) throws IOException { + if (!HttpStatus.OK.equals(response.getStatusCode())) { + throw new HttpServerErrorException(response.getStatusCode()); + } + if (logger.isDebugEnabled()) { + logger.debug("XHR receive headers: " + response.getHeaders()); + } + InputStream is = response.getBody(); + ByteArrayOutputStream os = new ByteArrayOutputStream(); + while (true) { + if (this.sockJsSession.isDisconnected()) { + if (logger.isDebugEnabled()) { + logger.debug("SockJS sockJsSession closed. Closing ClientHttpResponse."); + } + response.close(); + break; + } + int b = is.read(); + if (b == -1) { + if (os.size() > 0) { + handleFrame(os); + } + if (logger.isDebugEnabled()) { + logger.debug("XHR receive completed"); + } + break; + } + if (b == '\n') { + handleFrame(os); + } + else { + os.write(b); + } + } + return null; + } + + private void handleFrame(ByteArrayOutputStream os) { + byte[] bytes = os.toByteArray(); + os.reset(); + String content = new String(bytes, SockJsFrame.CHARSET); + if (logger.isTraceEnabled()) { + logger.trace("XHR receive content: " + content); + } + if (!PRELUDE.equals(content)) { + this.sockJsSession.handleFrame(new String(bytes, SockJsFrame.CHARSET)); + } + } + } + +} + diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java new file mode 100644 index 0000000000..5d8bef8cfe --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java @@ -0,0 +1,259 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.client.AbstractWebSocketClient; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.URI; +import java.security.Principal; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * A SockJS implementation of + * {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient} + * with HTTP-based fallback alternative simulating a WebSocket interaction. + * + * @author Rossen Stoyanchev + * @since 4.1 + * + * @see http://sockjs.org + * @see org.springframework.web.socket.sockjs.client.Transport + */ +public class SockJsClient extends AbstractWebSocketClient { + + private static final boolean jackson2Present = ClassUtils.isPresent( + "com.fasterxml.jackson.databind.ObjectMapper", SockJsClient.class.getClassLoader()); + + private static final Log logger = LogFactory.getLog(SockJsClient.class); + + + private final List transports; + + private InfoReceiver infoReceiver; + + private SockJsMessageCodec messageCodec; + + private TaskScheduler taskScheduler; + + private final Map infoCache = new ConcurrentHashMap(); + + + /** + * Create a {@code SockJsClient} with the given transports. + * @param transports the transports to use + */ + public SockJsClient(List transports) { + Assert.notEmpty(transports, "No transports provided"); + this.transports = new ArrayList(transports); + this.infoReceiver = initInfoReceiver(transports); + if (jackson2Present) { + this.messageCodec = new Jackson2SockJsMessageCodec(); + } + } + + private static InfoReceiver initInfoReceiver(List transports) { + for (Transport transport : transports) { + if (transport instanceof InfoReceiver) { + return ((InfoReceiver) transport); + } + } + return new RestTemplateXhrTransport(); + } + + + /** + * Configure the {@code InfoReceiver} to use to perform the SockJS "Info" + * request before the SockJS session starts. + * + *

By default this is initialized either by looking through the configured + * transports to find the first {@code XhrTransport} or by creating an + * instance of {@code RestTemplateXhrTransport}. + * + * @param infoReceiver the transport to use for the SockJS "Info" request + */ + public void setInfoReceiver(InfoReceiver infoReceiver) { + this.infoReceiver = infoReceiver; + } + + public InfoReceiver getInfoReceiver() { + return this.infoReceiver; + } + + /** + * Set the SockJsMessageCodec to use. + * + *

By default {@link org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec + * Jackson2SockJsMessageCodec} is used if Jackson is on the classpath. + * + * @param messageCodec the message messageCodec to use + */ + public void setMessageCodec(SockJsMessageCodec messageCodec) { + Assert.notNull(messageCodec, "'messageCodec' is required"); + this.messageCodec = messageCodec; + } + + public SockJsMessageCodec getMessageCodec() { + return this.messageCodec; + } + + /** + * Configure a {@code TaskScheduler} for scheduling a connect timeout task + * where the timeout value is calculated based on the duration of the initial + * SockJS info request. Having a connect timeout task is optional but can + * improve the speed with which the client falls back to alternative + * transport options. + * + *

By default no task scheduler is configured in which case it may take + * longer before a fallback transport can be used. + * + * @param taskScheduler the scheduler to use + */ + public void setTaskScheduler(TaskScheduler taskScheduler) { + this.taskScheduler = taskScheduler; + } + + public void clearServerInfoCache() { + this.infoCache.clear(); + } + + @Override + protected void assertUri(URI uri) { + Assert.notNull(uri, "uri must not be null"); + String scheme = uri.getScheme(); + Assert.isTrue(scheme != null && ("ws".equals(scheme) || "wss".equals(scheme) + || "http".equals(scheme) || "https".equals(scheme)), "Invalid scheme: " + scheme); + } + + @Override + protected ListenableFuture doHandshakeInternal(WebSocketHandler handler, + HttpHeaders handshakeHeaders, URI url, List protocols, + List extensions, Map attributes) { + + SettableListenableFuture connectFuture = new SettableListenableFuture(); + try { + SockJsUrlInfo sockJsUrlInfo = new SockJsUrlInfo(url); + ServerInfo serverInfo = getServerInfo(sockJsUrlInfo); + createFallbackChain(sockJsUrlInfo, handshakeHeaders, serverInfo).connect(handler, connectFuture); + } + catch (Throwable exception) { + if (logger.isErrorEnabled()) { + logger.error("Initial SockJS \"Info\" request to server failed, url=" + url, exception); + } + connectFuture.setException(exception); + } + return connectFuture; + } + + private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo) { + URI infoUrl = sockJsUrlInfo.getInfoUrl(); + ServerInfo info = this.infoCache.get(infoUrl); + if (info == null) { + long start = System.currentTimeMillis(); + String response = this.infoReceiver.executeInfoRequest(infoUrl); + long infoRequestTime = System.currentTimeMillis() - start; + info = new ServerInfo(response, infoRequestTime); + this.infoCache.put(infoUrl, info); + } + return info; + } + + private DefaultTransportRequest createFallbackChain(SockJsUrlInfo urlInfo, HttpHeaders headers, ServerInfo serverInfo) { + List requests = new ArrayList(this.transports.size()); + for (Transport transport : this.transports) { + if (transport instanceof XhrTransport) { + XhrTransport xhrTransport = (XhrTransport) transport; + if (!xhrTransport.isXhrStreamingDisabled()) { + addRequest(requests, urlInfo, headers, serverInfo, transport, TransportType.XHR_STREAMING); + } + addRequest(requests, urlInfo, headers, serverInfo, transport, TransportType.XHR); + } + else if (serverInfo.isWebSocketEnabled()) { + addRequest(requests, urlInfo, headers, serverInfo, transport, TransportType.WEBSOCKET); + } + } + Assert.notEmpty(requests, + "0 transports for request to " + urlInfo + " . Configured transports: " + + this.transports + ". SockJS server webSocketEnabled=" + serverInfo.isWebSocketEnabled()); + for (int i = 0; i < requests.size() - 1; i++) { + requests.get(i).setFallbackRequest(requests.get(i + 1)); + } + return requests.get(0); + } + + private void addRequest(List requests, SockJsUrlInfo info, HttpHeaders headers, + ServerInfo serverInfo, Transport transport, TransportType type) { + + DefaultTransportRequest request = new DefaultTransportRequest(info, headers, transport, type, getMessageCodec()); + request.setUser(getUser()); + if (this.taskScheduler != null) { + request.setTimeoutValue(serverInfo.getRetransmissionTimeout()); + request.setTimeoutScheduler(this.taskScheduler); + } + requests.add(request); + } + + /** + * Return the user to associate with the SockJS session and make available via + * {@link org.springframework.web.socket.WebSocketSession#getPrincipal() + * WebSocketSession#getPrincipal()}. + *

By default this method returns {@code null}. + * @return the user to associate with the session, possibly {@code null} + */ + protected Principal getUser() { + return null; + } + + + private static class ServerInfo { + + private final boolean webSocketEnabled; + + private final long responseTime; + + + private ServerInfo(String response, long responseTime) { + this.responseTime = responseTime; + this.webSocketEnabled = !response.matches(".*[\"']websocket[\"']\\s*:\\s*false.*"); + } + + public boolean isWebSocketEnabled() { + return this.webSocketEnabled; + } + + public long getRetransmissionTimeout() { + return (this.responseTime > 100 ? 4 * this.responseTime : this.responseTime + 300); + } + } + +} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfo.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfo.java new file mode 100644 index 0000000000..6530f0a62b --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfo.java @@ -0,0 +1,115 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.springframework.util.AlternativeJdkIdGenerator; +import org.springframework.util.IdGenerator; +import org.springframework.web.socket.sockjs.transport.TransportType; +import org.springframework.web.util.UriComponentsBuilder; + +import java.net.URI; +import java.util.UUID; + +/** + * Given the base URL to a SockJS server endpoint, also provides methods to + * generate and obtain session and a server id used for construct a transport URL. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class SockJsUrlInfo { + + private static final IdGenerator idGenerator = new AlternativeJdkIdGenerator(); + + + private final URI sockJsUrl; + + private String serverId; + + private String sessionId; + + private UUID uuid; + + + public SockJsUrlInfo(URI sockJsUrl) { + this.sockJsUrl = sockJsUrl; + } + + + public URI getSockJsUrl() { + return this.sockJsUrl; + } + + public String getServerId() { + if (this.serverId == null) { + this.serverId = String.valueOf(Math.abs(getUuid().getMostSignificantBits()) % 1000); + } + return this.serverId; + } + + public String getSessionId() { + if (this.sessionId == null) { + this.sessionId = getUuid().toString().replace("-",""); + } + return this.sessionId; + } + + protected UUID getUuid() { + if (this.uuid == null) { + this.uuid = idGenerator.generateId(); + } + return this.uuid; + } + + public URI getInfoUrl() { + return UriComponentsBuilder.fromUri(this.sockJsUrl) + .scheme(getScheme(TransportType.XHR)) + .pathSegment("info") + .build(true).toUri(); + } + + public URI getTransportUrl(TransportType transportType) { + return UriComponentsBuilder.fromUri(this.sockJsUrl) + .scheme(getScheme(transportType)) + .pathSegment(getServerId()) + .pathSegment(getSessionId()) + .pathSegment(transportType.toString()) + .build(true).toUri(); + } + + private String getScheme(TransportType transportType) { + String scheme = this.sockJsUrl.getScheme(); + if (TransportType.WEBSOCKET.equals(transportType)) { + if (!scheme.startsWith("ws")) { + scheme = ("https".equals(scheme) ? "wss" : "ws"); + } + } + else { + if (!scheme.startsWith("http")) { + scheme = ("wss".equals(scheme) ? "https" : "http"); + } + } + return scheme; + } + + + @Override + public String toString() { + return "SockJsUrlInfo[url=" + this.sockJsUrl + "]"; + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/Transport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/Transport.java new file mode 100644 index 0000000000..41d554a5ef --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/Transport.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.socket.sockjs.client; + + +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; + +/** + * A client-side implementation for a SockJS transport. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public interface Transport { + + /** + * Connect the transport. + * + * @param request the transport request. + * @param webSocketHandler the application handler to delegate lifecycle events to. + * @return a future to indicate success or failure to connect. + */ + ListenableFuture connect(TransportRequest request, WebSocketHandler webSocketHandler); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java new file mode 100644 index 0000000000..5e92b28296 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java @@ -0,0 +1,53 @@ +package org.springframework.web.socket.sockjs.client; + +import org.springframework.http.HttpHeaders; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.URI; +import java.security.Principal; + +/** + * Represents a request to connect to a SockJS service using a specific + * Transport. A single SockJS request however may require falling back + * and therefore multiple TransportRequest instances. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public interface TransportRequest { + + /** + * Return information about the SockJS URL including server and session id.. + */ + SockJsUrlInfo getSockJsUrlInfo(); + + /** + * Return the headers to send with the connect request. + */ + HttpHeaders getHandshakeHeaders(); + + /** + * Return the transport URL for the given transport. + * For an {@link XhrTransport} this is the URL for receiving messages. + */ + URI getTransportUrl(); + + /** + * Return the user associated with the request, if any. + */ + Principal getUser(); + + /** + * Return the message codec to use for encoding SockJS messages. + */ + SockJsMessageCodec getMessageCodec(); + + /** + * Register a timeout cleanup task to invoke if the SockJS session is not + * fully established within the calculated retransmission timeout period. + */ + void addTimeoutTask(Runnable runnable); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketClientSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketClientSockJsSession.java new file mode 100644 index 0000000000..543a6a2c73 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketClientSockJsSession.java @@ -0,0 +1,136 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + + +import org.springframework.util.Assert; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.adapter.NativeWebSocketSession; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.List; + +/** + * An extension of {@link AbstractClientSockJsSession} wrapping and delegating + * to an actual WebSocket session. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class WebSocketClientSockJsSession extends AbstractClientSockJsSession implements NativeWebSocketSession { + + private WebSocketSession webSocketSession; + + + public WebSocketClientSockJsSession(TransportRequest request, WebSocketHandler handler, + SettableListenableFuture connectFuture) { + + super(request, handler, connectFuture); + } + + + @Override + public Object getNativeSession() { + return this.webSocketSession; + } + + @SuppressWarnings("unchecked") + @Override + public T getNativeSession(Class requiredType) { + if (requiredType != null) { + if (requiredType.isInstance(this.webSocketSession)) { + return (T) this.webSocketSession; + } + } + return null; + } + + @Override + public InetSocketAddress getLocalAddress() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getLocalAddress(); + } + + @Override + public InetSocketAddress getRemoteAddress() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getRemoteAddress(); + } + + @Override + public String getAcceptedProtocol() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getAcceptedProtocol(); + } + + @Override + public void setTextMessageSizeLimit(int messageSizeLimit) { + checkDelegateSessionInitialized(); + this.webSocketSession.setTextMessageSizeLimit(messageSizeLimit); + } + + @Override + public int getTextMessageSizeLimit() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getTextMessageSizeLimit(); + } + + @Override + public void setBinaryMessageSizeLimit(int messageSizeLimit) { + checkDelegateSessionInitialized(); + this.webSocketSession.setBinaryMessageSizeLimit(messageSizeLimit); + } + + @Override + public int getBinaryMessageSizeLimit() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getBinaryMessageSizeLimit(); + } + + @Override + public List getExtensions() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getExtensions(); + } + + private void checkDelegateSessionInitialized() { + Assert.state(this.webSocketSession != null, "WebSocketSession not yet initialized"); + } + + public void initializeDelegateSession(WebSocketSession session) { + this.webSocketSession = session; + } + + @Override + protected void sendInternal(TextMessage textMessage) throws IOException { + this.webSocketSession.sendMessage(textMessage); + } + + @Override + protected void disconnect(CloseStatus status) throws IOException { + if (this.webSocketSession != null) { + this.webSocketSession.close(status); + } + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketTransport.java new file mode 100644 index 0000000000..ef36a96840 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketTransport.java @@ -0,0 +1,129 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketHttpHeaders; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.client.WebSocketClient; +import org.springframework.web.socket.handler.TextWebSocketHandler; + +import java.net.URI; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * A SockJS {@link Transport} that uses a + * {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class WebSocketTransport implements Transport { + + private static Log logger = LogFactory.getLog(WebSocketTransport.class); + + private final WebSocketClient webSocketClient; + + + public WebSocketTransport(WebSocketClient webSocketClient) { + Assert.notNull(webSocketClient, "'webSocketClient' is required"); + this.webSocketClient = webSocketClient; + } + + + /** + * Return the configured {@code WebSocketClient}. + */ + public WebSocketClient getWebSocketClient() { + return this.webSocketClient; + } + + @Override + public ListenableFuture connect(TransportRequest request, WebSocketHandler handler) { + final SettableListenableFuture future = new SettableListenableFuture(); + WebSocketClientSockJsSession session = new WebSocketClientSockJsSession(request, handler, future); + handler = new ClientSockJsWebSocketHandler(session); + request.addTimeoutTask(session.getTimeoutTask()); + + URI url = request.getTransportUrl(); + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(request.getHandshakeHeaders()); + if (logger.isDebugEnabled()) { + logger.debug("Opening WebSocket connection, url=" + url); + } + this.webSocketClient.doHandshake(handler, headers, url).addCallback( + new ListenableFutureCallback() { + @Override + public void onSuccess(WebSocketSession webSocketSession) { + // WebSocket session ready, SockJS Session not yet + } + @Override + public void onFailure(Throwable t) { + future.setException(t); + } + }); + return future; + } + + @Override + public String toString() { + return "WebSocketTransport[client=" + this.webSocketClient + "]"; + } + + + private static class ClientSockJsWebSocketHandler extends TextWebSocketHandler { + + private final WebSocketClientSockJsSession sockJsSession; + + private final AtomicInteger connectCount = new AtomicInteger(0); + + + private ClientSockJsWebSocketHandler(WebSocketClientSockJsSession session) { + Assert.notNull(session); + this.sockJsSession = session; + } + + @Override + public void afterConnectionEstablished(WebSocketSession webSocketSession) throws Exception { + Assert.isTrue(this.connectCount.compareAndSet(0, 1)); + this.sockJsSession.initializeDelegateSession(webSocketSession); + } + + @Override + public void handleTextMessage(WebSocketSession webSocketSession, TextMessage message) throws Exception { + this.sockJsSession.handleFrame(message.getPayload()); + } + + @Override + public void handleTransportError(WebSocketSession webSocketSession, Throwable ex) throws Exception { + this.sockJsSession.handleTransportError(ex); + } + + @Override + public void afterConnectionClosed(WebSocketSession webSocketSession, CloseStatus status) throws Exception { + this.sockJsSession.afterTransportClosed(status); + } + } + +} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java new file mode 100644 index 0000000000..92de51cea7 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java @@ -0,0 +1,111 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.springframework.util.Assert; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.util.List; + + +/** + * An extension of {@link AbstractClientSockJsSession} for use with HTTP + * transports simulating a WebSocket session. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class XhrClientSockJsSession extends AbstractClientSockJsSession { + + private final URI sendUrl; + + private final XhrTransport transport; + + private int textMessageSizeLimit = -1; + + private int binaryMessageSizeLimit = -1; + + + public XhrClientSockJsSession(TransportRequest request, WebSocketHandler handler, + XhrTransport transport, SettableListenableFuture connectFuture) { + + super(request, handler, connectFuture); + Assert.notNull(transport, "'restTemplate' is required"); + this.sendUrl = request.getSockJsUrlInfo().getTransportUrl(TransportType.XHR_SEND); + this.transport = transport; + } + + + @Override + public InetSocketAddress getLocalAddress() { + return null; + } + + @Override + public InetSocketAddress getRemoteAddress() { + return new InetSocketAddress(getUri().getHost(), getUri().getPort()); + } + + @Override + public String getAcceptedProtocol() { + return null; + } + + @Override + public void setTextMessageSizeLimit(int messageSizeLimit) { + this.textMessageSizeLimit = messageSizeLimit; + } + + @Override + public int getTextMessageSizeLimit() { + return this.textMessageSizeLimit; + } + + @Override + public void setBinaryMessageSizeLimit(int messageSizeLimit) { + this.binaryMessageSizeLimit = -1; + } + + @Override + public int getBinaryMessageSizeLimit() { + return this.binaryMessageSizeLimit; + } + + @Override + public List getExtensions() { + return null; + } + + @Override + protected void sendInternal(TextMessage message) { + this.transport.executeSendRequest(this.sendUrl, message); + } + + @Override + protected void disconnect(CloseStatus status) { + // Nothing to do, XHR transports check if session is disconnected + } + +} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java new file mode 100644 index 0000000000..726b202464 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java @@ -0,0 +1,40 @@ +package org.springframework.web.socket.sockjs.client; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.ResponseEntity; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; + +import java.net.URI; + +/** + * A SockJS {@link Transport} that uses HTTP requests to simulate a WebSocket + * interaction. The {@code connect} method of the base {@code Transport} interface + * is used to receive messages from the server while the + * {@link #executeSendRequest(java.net.URI, org.springframework.web.socket.TextMessage) + * executeSendRequest(URI, TextMessage)} method here is used to send messages. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public interface XhrTransport extends Transport, InfoReceiver { + + /** + * An {@code XhrTransport} supports both the "xhr_streaming" and "xhr" SockJS + * server transports. From a client perspective there is no implementation + * difference. + * + *

By default an {@code XhrTransport} will be used with "xhr_streaming" + * first and then with "xhr", if the streaming fails to connect. In some + * cases it may be useful to suppress streaming so that only "xhr" is used. + */ + boolean isXhrStreamingDisabled(); + + /** + * Execute a request to send the message to the server. + * @param transportUrl the URL for sending messages. + * @param message the message to send + */ + void executeSendRequest(URI transportUrl, TextMessage message); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/package-info.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/package-info.java new file mode 100644 index 0000000000..6ec13fd08e --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * SockJS client implementation of + * {@link org.springframework.web.socket.client.WebSocketClient}. + */ +package org.springframework.web.socket.sockjs.client; + diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java index 7e07bf7ced..99888fcdd5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java @@ -145,7 +145,7 @@ public class SockJsFrame { if (!(other instanceof SockJsFrame)) { return false; } - return this.content.equals(((SockJsFrame) other).content); + return (this.type.equals(((SockJsFrame) other).type) && this.content.equals(((SockJsFrame) other).content)); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java index d9a06db592..dc8d020b74 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java @@ -238,7 +238,7 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem } else { response.setStatusCode(HttpStatus.NOT_FOUND); - logger.warn("Session not found"); + logger.warn("Session not found, sessionId=" + sessionId); return; } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java index 9bea3bcd76..f4b2fd3f91 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java @@ -78,6 +78,7 @@ public class JettyWebSocketTestServer implements WebSocketTestServer { @Override public void stop() throws Exception { if (this.jettyServer.isRunning()) { + this.jettyServer.setStopTimeout(0); this.jettyServer.stop(); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java new file mode 100644 index 0000000000..016b7ef924 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java @@ -0,0 +1,394 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.WebSocketTestServer; +import org.springframework.web.socket.config.annotation.EnableWebSocket; +import org.springframework.web.socket.config.annotation.WebSocketConfigurer; +import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; +import org.springframework.web.socket.handler.TextWebSocketHandler; +import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.RequestUpgradeStrategy; +import org.springframework.web.socket.server.support.DefaultHandshakeHandler; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.hamcrest.Matchers.*; + +/** + * Integration tests using the + * {@link org.springframework.web.socket.sockjs.client.SockJsClient}. + * against actual SockJS server endpoints. + * + * @author Rossen Stoyanchev + */ +public abstract class AbstractSockJsIntegrationTests { + + protected Log logger = LogFactory.getLog(getClass()); + + private WebSocketTestServer server; + + private AnnotationConfigWebApplicationContext wac; + + private ErrorFilter errorFilter; + + private String baseUrl; + + + @Before + public void setup() throws Exception { + this.errorFilter = new ErrorFilter(); + this.wac = new AnnotationConfigWebApplicationContext(); + this.wac.register(TestConfig.class, upgradeStrategyConfigClass()); + this.wac.refresh(); + this.server = createWebSocketTestServer(); + this.server.deployConfig(this.wac, this.errorFilter); + this.server.start(); + this.baseUrl = "http://localhost:" + this.server.getPort(); + } + + @After + public void teardown() throws Exception { + try { + this.server.undeployConfig(); + } + catch (Throwable t) { + logger.error("Failed to undeploy application config", t); + } + try { + this.server.stop(); + } + catch (Throwable t) { + logger.error("Failed to stop server", t); + } + } + + protected abstract WebSocketTestServer createWebSocketTestServer(); + + protected abstract Class upgradeStrategyConfigClass(); + + protected abstract Transport getWebSocketTransport(); + + protected abstract AbstractXhrTransport getXhrTransport(); + + protected SockJsClient createSockJsClient(Transport... transports) { + return new SockJsClient(Arrays.asList(transports)); + } + + @Test + public void echoWebSocket() throws Exception { + testEcho(100, getWebSocketTransport()); + } + + @Test + public void echoXhrStreaming() throws Exception { + testEcho(100, getXhrTransport()); + } + + @Test + public void echoXhr() throws Exception { + AbstractXhrTransport xhrTransport = getXhrTransport(); + xhrTransport.setXhrStreamingDisabled(true); + testEcho(100, xhrTransport); + } + + @Test + public void closeAfterOneMessageWebSocket() throws Exception { + testCloseAfterOneMessage(getWebSocketTransport()); + } + + @Test + public void closeAfterOneMessageXhrStreaming() throws Exception { + testCloseAfterOneMessage(getXhrTransport()); + } + + @Test + public void closeAfterOneMessageXhr() throws Exception { + AbstractXhrTransport xhrTransport = getXhrTransport(); + xhrTransport.setXhrStreamingDisabled(true); + testCloseAfterOneMessage(xhrTransport); + } + + @Test + public void infoRequestFailure() throws Exception { + TestClientHandler handler = new TestClientHandler(); + this.errorFilter.responseStatusMap.put("/info", 500); + CountDownLatch latch = new CountDownLatch(1); + createSockJsClient(getWebSocketTransport()).doHandshake(handler, this.baseUrl + "/echo").addCallback( + new ListenableFutureCallback() { + @Override + public void onSuccess(WebSocketSession result) { + + } + @Override + public void onFailure(Throwable t) { + latch.countDown(); + } + } + ); + assertTrue(latch.await(5000, TimeUnit.MILLISECONDS)); + } + + @Test + public void fallbackAfterTransportFailure() throws Exception { + this.errorFilter.responseStatusMap.put("/websocket", 200); + this.errorFilter.responseStatusMap.put("/xhr_streaming", 500); + TestClientHandler handler = new TestClientHandler(); + Transport[] transports = { getWebSocketTransport(), getXhrTransport() }; + WebSocketSession session = createSockJsClient(transports).doHandshake(handler, this.baseUrl + "/echo").get(); + assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, session.getClass()); + TextMessage message = new TextMessage("message1"); + session.sendMessage(message); + handler.awaitMessage(message, 5000); + } + + @Test(timeout = 5000) + public void fallbackAfterConnectTimeout() throws Exception { + TestClientHandler clientHandler = new TestClientHandler(); + this.errorFilter.sleepDelayMap.put("/xhr_streaming", 10000L); + this.errorFilter.responseStatusMap.put("/xhr_streaming", 503); + SockJsClient sockJsClient = createSockJsClient(getXhrTransport()); + sockJsClient.setTaskScheduler(this.wac.getBean(ThreadPoolTaskScheduler.class)); + WebSocketSession clientSession = sockJsClient.doHandshake(clientHandler, this.baseUrl + "/echo").get(); + assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, clientSession.getClass()); + TextMessage message = new TextMessage("message1"); + clientSession.sendMessage(message); + clientHandler.awaitMessage(message, 5000); + clientSession.close(); + } + + + private void testEcho(int messageCount, Transport transport) throws Exception { + List messages = new ArrayList<>(); + for (int i = 0; i < messageCount; i++) { + messages.add(new TextMessage("m" + i)); + } + TestClientHandler handler = new TestClientHandler(); + WebSocketSession session = createSockJsClient(transport).doHandshake(handler, this.baseUrl + "/echo").get(); + for (TextMessage message : messages) { + session.sendMessage(message); + } + handler.awaitMessageCount(messageCount, 5000); + for (TextMessage message : messages) { + assertTrue("Message not received: " + message, handler.receivedMessages.remove(message)); + } + assertEquals("Remaining messages: " + handler.receivedMessages, 0, handler.receivedMessages.size()); + session.close(); + } + + private void testCloseAfterOneMessage(Transport transport) throws Exception { + TestClientHandler clientHandler = new TestClientHandler(); + createSockJsClient(transport).doHandshake(clientHandler, this.baseUrl + "/test").get(); + TestServerHandler serverHandler = this.wac.getBean(TestServerHandler.class); + + assertNotNull("afterConnectionEstablished should have been called", clientHandler.session); + serverHandler.awaitSession(5000); + + TextMessage message = new TextMessage("message1"); + serverHandler.session.sendMessage(message); + clientHandler.awaitMessage(message, 5000); + + CloseStatus expected = new CloseStatus(3500, "Oops"); + serverHandler.session.close(expected); + CloseStatus actual = clientHandler.awaitCloseStatus(5000); + if (transport instanceof XhrTransport) { + assertThat(actual, Matchers.anyOf(equalTo(expected), equalTo(new CloseStatus(3000, "Go away!")))); + } + else { + assertEquals(expected, actual); + } + } + + + @Configuration + @EnableWebSocket + static class TestConfig implements WebSocketConfigurer { + + @Autowired + private RequestUpgradeStrategy upgradeStrategy; + + @Override + public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { + HandshakeHandler handshakeHandler = new DefaultHandshakeHandler(this.upgradeStrategy); + registry.addHandler(new EchoHandler(), "/echo").setHandshakeHandler(handshakeHandler).withSockJS(); + registry.addHandler(testServerHandler(), "/test").setHandshakeHandler(handshakeHandler).withSockJS(); + } + + @Bean + public TestServerHandler testServerHandler() { + return new TestServerHandler(); + } + } + + private static interface Condition { + boolean match(); + } + + private static void awaitEvent(Condition condition, long timeToWait, String description) { + long timeToSleep = 200; + for (int i = 0 ; i < Math.floor(timeToWait / timeToSleep); i++) { + if (condition.match()) { + return; + } + try { + Thread.sleep(timeToSleep); + } + catch (InterruptedException e) { + throw new IllegalStateException("Interrupted while waiting for " + description, e); + } + } + throw new IllegalStateException("Timed out waiting for " + description); + } + + private static class TestClientHandler extends TextWebSocketHandler { + + private final BlockingQueue receivedMessages = new LinkedBlockingQueue<>(); + + private volatile WebSocketSession session; + + private volatile CloseStatus closeStatus; + + + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + this.session = session; + } + + @Override + protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { + this.receivedMessages.add(message); + } + + @Override + public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { + this.closeStatus = status; + } + + public void awaitMessageCount(final int count, long timeToWait) throws Exception { + awaitEvent(() -> receivedMessages.size() >= count, timeToWait, + count + " number of messages. Received so far: " + this.receivedMessages); + } + + public void awaitMessage(TextMessage expected, long timeToWait) throws InterruptedException { + TextMessage actual = this.receivedMessages.poll(timeToWait, TimeUnit.MILLISECONDS); + assertNotNull("Timed out waiting for [" + expected + "]", actual); + assertEquals(expected, actual); + } + + public CloseStatus awaitCloseStatus(long timeToWait) throws InterruptedException { + awaitEvent(() -> this.closeStatus != null, timeToWait, " CloseStatus"); + return this.closeStatus; + } + } + + private static class TestServerHandler extends TextWebSocketHandler { + + private WebSocketSession session; + + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + this.session = session; + } + + public WebSocketSession awaitSession(long timeToWait) throws InterruptedException { + awaitEvent(() -> this.session != null, timeToWait, " session"); + return this.session; + } + } + + private static class EchoHandler extends TextWebSocketHandler { + + @Override + protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { + session.sendMessage(message); + } + } + + private static class ErrorFilter implements Filter { + + private final Map responseStatusMap = new HashMap<>(); + + private final Map sleepDelayMap = new HashMap<>(); + + @Override + public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain) throws IOException, ServletException { + for (String suffix : this.sleepDelayMap.keySet()) { + if (((HttpServletRequest) req).getRequestURI().endsWith(suffix)) { + try { + Thread.sleep(this.sleepDelayMap.get(suffix)); + break; + } + catch (InterruptedException e) { + e.printStackTrace(); + } + } + } + for (String suffix : this.responseStatusMap.keySet()) { + if (((HttpServletRequest) req).getRequestURI().endsWith(suffix)) { + ((HttpServletResponse) resp).sendError(this.responseStatusMap.get(suffix)); + return; + } + } + chain.doFilter(req, resp); + } + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + } + + @Override + public void destroy() { + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java new file mode 100644 index 0000000000..9d3f76b460 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java @@ -0,0 +1,280 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.util.List; + +import static org.junit.Assert.assertThat; +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.*; + +/** + * Unit tests for + * {@link org.springframework.web.socket.sockjs.client.AbstractClientSockJsSession}. + * + * @author Rossen Stoyanchev + */ +public class ClientSockJsSessionTests { + + private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec(); + + private TestClientSockJsSession session; + + private WebSocketHandler handler; + + private SettableListenableFuture connectFuture; + + @Rule + public final ExpectedException thrown = ExpectedException.none(); + + + @Before + public void setup() throws Exception { + SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); + Transport transport = mock(Transport.class); + TransportRequest request = new DefaultTransportRequest(urlInfo, null, transport, TransportType.XHR, CODEC); + this.handler = mock(WebSocketHandler.class); + this.connectFuture = new SettableListenableFuture<>(); + this.session = new TestClientSockJsSession(request, this.handler, this.connectFuture); + } + + + @Test + public void handleFrameOpen() throws Exception { + assertThat(this.session.isOpen(), is(false)); + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + assertThat(this.session.isOpen(), is(true)); + assertTrue(this.connectFuture.isDone()); + assertThat(this.connectFuture.get(), sameInstance(this.session)); + verify(this.handler).afterConnectionEstablished(this.session); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameOpenWhenStatusNotNew() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + assertThat(this.session.isOpen(), is(true)); + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(1006, "Server lost session"))); + } + + @Test + public void handleFrameOpenWithWebSocketHandlerException() throws Exception { + doThrow(new IllegalStateException("Fake error")).when(this.handler).afterConnectionEstablished(this.session); + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + assertThat(this.session.isOpen(), is(true)); + } + + @Test + public void handleFrameMessage() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent()); + verify(this.handler).afterConnectionEstablished(this.session); + verify(this.handler).handleMessage(this.session, new TextMessage("foo")); + verify(this.handler).handleMessage(this.session, new TextMessage("bar")); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameMessageWhenNotOpen() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.close(); + reset(this.handler); + this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent()); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameMessageWithBadData() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.handleFrame("a['bad data"); + assertThat(this.session.isOpen(), equalTo(false)); + assertThat(this.session.disconnectStatus, equalTo(CloseStatus.BAD_DATA)); + verify(this.handler).afterConnectionEstablished(this.session); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameMessageWithWebSocketHandlerException() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + doThrow(new IllegalStateException("Fake error")).when(this.handler).handleMessage(this.session, new TextMessage("foo")); + doThrow(new IllegalStateException("Fake error")).when(this.handler).handleMessage(this.session, new TextMessage("bar")); + this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent()); + assertThat(this.session.isOpen(), equalTo(true)); + verify(this.handler).afterConnectionEstablished(this.session); + verify(this.handler).handleMessage(this.session, new TextMessage("foo")); + verify(this.handler).handleMessage(this.session, new TextMessage("bar")); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameClose() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.handleFrame(SockJsFrame.closeFrame(1007, "").getContent()); + assertThat(this.session.isOpen(), equalTo(false)); + assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(1007, ""))); + verify(this.handler).afterConnectionEstablished(this.session); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleTransportError() throws Exception { + final IllegalStateException ex = new IllegalStateException("Fake error"); + this.session.handleTransportError(ex); + verify(this.handler).handleTransportError(this.session, ex); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void afterTransportClosed() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.afterTransportClosed(CloseStatus.SERVER_ERROR); + assertThat(this.session.isOpen(), equalTo(false)); + verify(this.handler).afterConnectionEstablished(this.session); + verify(this.handler).afterConnectionClosed(this.session, CloseStatus.SERVER_ERROR); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void close() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.close(); + assertThat(this.session.isOpen(), equalTo(false)); + assertThat(this.session.disconnectStatus, equalTo(CloseStatus.NORMAL)); + verify(this.handler).afterConnectionEstablished(this.session); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void closeWithStatus() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.close(new CloseStatus(3000, "reason")); + assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(3000, "reason"))); + } + + @Test + public void closeWithNullStatus() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("Invalid close status"); + this.session.close(null); + } + + @Test + public void closeWithStatusOutOfRange() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("Invalid close status"); + this.session.close(new CloseStatus(2999, "reason")); + } + + @Test + public void timeoutTask() { + this.session.getTimeoutTask().run(); + assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(2007, "Transport timed out"))); + } + + @Test + public void send() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.sendMessage(new TextMessage("foo")); + assertThat(this.session.sentMessage, equalTo(new TextMessage("[\"foo\"]"))); + } + + + private static class TestClientSockJsSession extends AbstractClientSockJsSession { + + private TextMessage sentMessage; + + private CloseStatus disconnectStatus; + + + protected TestClientSockJsSession(TransportRequest request, WebSocketHandler handler, + SettableListenableFuture connectFuture) { + super(request, handler, connectFuture); + } + + @Override + protected void sendInternal(TextMessage textMessage) throws IOException { + this.sentMessage = textMessage; + } + + @Override + protected void disconnect(CloseStatus status) throws IOException { + this.disconnectStatus = status; + } + + @Override + public InetSocketAddress getLocalAddress() { + return null; + } + + @Override + public InetSocketAddress getRemoteAddress() { + return null; + } + + @Override + public String getAcceptedProtocol() { + return null; + } + + @Override + public void setTextMessageSizeLimit(int messageSizeLimit) { + + } + + @Override + public int getTextMessageSizeLimit() { + return 0; + } + + @Override + public void setBinaryMessageSizeLimit(int messageSizeLimit) { + + } + + @Override + public int getBinaryMessageSizeLimit() { + return 0; + } + + @Override + public List getExtensions() { + return null; + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java new file mode 100644 index 0000000000..6d1eeda631 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java @@ -0,0 +1,139 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpHeaders; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.io.IOException; +import java.net.URI; +import java.util.Date; +import java.util.concurrent.ExecutionException; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; + +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Unit tests for {@link DefaultTransportRequest}. + * + * @author Rossen Stoyanchev + */ +public class DefaultTransportRequestTests { + + private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec(); + + + private SettableListenableFuture connectFuture; + + private ListenableFutureCallback connectCallback; + + private TestTransport webSocketTransport; + + private TestTransport xhrTransport; + + + @Rule + public final ExpectedException thrown = ExpectedException.none(); + + + @SuppressWarnings("unchecked") + @Before + public void setup() throws Exception { + this.connectCallback = mock(ListenableFutureCallback.class); + this.connectFuture = new SettableListenableFuture<>(); + this.connectFuture.addCallback(this.connectCallback); + this.webSocketTransport = new TestTransport("WebSocketTestTransport"); + this.xhrTransport = new TestTransport("XhrTestTransport"); + } + + + @Test + @SuppressWarnings("unchecked") + public void connect() throws Exception { + DefaultTransportRequest request = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET); + request.connect(null, this.connectFuture); + WebSocketSession session = mock(WebSocketSession.class); + this.webSocketTransport.getConnectCallback().onSuccess(session); + assertSame(session, this.connectFuture.get()); + } + + @Test + public void fallbackAfterTransportError() throws Exception { + DefaultTransportRequest request1 = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET); + DefaultTransportRequest request2 = createTransportRequest(this.xhrTransport, TransportType.XHR_STREAMING); + request1.setFallbackRequest(request2); + request1.connect(null, this.connectFuture); + + // Transport error => fallback + this.webSocketTransport.getConnectCallback().onFailure(new IOException("Fake exception 1")); + assertFalse(this.connectFuture.isDone()); + assertTrue(this.xhrTransport.invoked()); + + // Transport error => no more fallback + this.xhrTransport.getConnectCallback().onFailure(new IOException("Fake exception 2")); + assertTrue(this.connectFuture.isDone()); + this.thrown.expect(ExecutionException.class); + this.thrown.expectMessage("Fake exception 2"); + this.connectFuture.get(); + } + + @Test + public void fallbackAfterTimeout() throws Exception { + TaskScheduler scheduler = mock(TaskScheduler.class); + Runnable sessionCleanupTask = mock(Runnable.class); + DefaultTransportRequest request1 = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET); + DefaultTransportRequest request2 = createTransportRequest(this.xhrTransport, TransportType.XHR_STREAMING); + request1.setFallbackRequest(request2); + request1.setTimeoutScheduler(scheduler); + request1.addTimeoutTask(sessionCleanupTask); + request1.connect(null, this.connectFuture); + + assertTrue(this.webSocketTransport.invoked()); + assertFalse(this.xhrTransport.invoked()); + + // Get and invoke the scheduled timeout task + ArgumentCaptor taskCaptor = ArgumentCaptor.forClass(Runnable.class); + verify(scheduler).schedule(taskCaptor.capture(), any(Date.class)); + verifyNoMoreInteractions(scheduler); + taskCaptor.getValue().run(); + + assertTrue(this.xhrTransport.invoked()); + verify(sessionCleanupTask).run(); + } + + protected DefaultTransportRequest createTransportRequest(Transport transport, TransportType type) throws Exception { + SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); + return new DefaultTransportRequest(urlInfo, new HttpHeaders(), transport, type, CODEC); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/JettySockJsIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/JettySockJsIntegrationTests.java new file mode 100644 index 0000000000..89120794ec --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/JettySockJsIntegrationTests.java @@ -0,0 +1,101 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.junit.After; +import org.junit.Before; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.socket.JettyWebSocketTestServer; +import org.springframework.web.socket.client.jetty.JettyWebSocketClient; +import org.springframework.web.socket.server.RequestUpgradeStrategy; +import org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy; + +import java.util.ArrayList; +import java.util.List; + +/** + * SockJS integration tests using Jetty for client and server. + * + * @author Rossen Stoyanchev + */ +public class JettySockJsIntegrationTests extends AbstractSockJsIntegrationTests { + + private WebSocketClient webSocketClient; + + private HttpClient httpClient; + + + @Before + public void setup() throws Exception { + super.setup(); + this.webSocketClient = new WebSocketClient(); + this.webSocketClient.start(); + this.httpClient = new HttpClient(); + this.httpClient.start(); + } + + @After + public void teardown() throws Exception { + super.teardown(); + try { + this.webSocketClient.stop(); + } + catch (Throwable ex) { + logger.error("Failed to stop Jetty WebSocketClient", ex); + } + try { + this.httpClient.stop(); + } + catch (Throwable ex) { + logger.error("Failed to stop Jetty HttpClient", ex); + } + } + + @Override + protected JettyWebSocketTestServer createWebSocketTestServer() { + return new JettyWebSocketTestServer(); + } + + @Override + protected Class upgradeStrategyConfigClass() { + return JettyTestConfig.class; + } + + @Override + protected Transport getWebSocketTransport() { + return new WebSocketTransport(new JettyWebSocketClient(this.webSocketClient)); + } + + @Override + protected AbstractXhrTransport getXhrTransport() { + return new JettyXhrTransport(this.httpClient); + } + + + @Configuration + static class JettyTestConfig { + + @Bean + public RequestUpgradeStrategy upgradeStrategy() { + return new JettyRequestUpgradeStrategy(); + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java new file mode 100644 index 0000000000..f5b9318d4a --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java @@ -0,0 +1,228 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompEncoder; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.client.RequestCallback; +import org.springframework.web.client.ResponseExtractor; +import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingDeque; + +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Unit tests for {@link RestTemplateXhrTransport}. + * + * @author Rossen Stoyanchev + */ +public class RestTemplateXhrTransportTests { + + private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec(); + + private WebSocketHandler webSocketHandler; + + + @Before + public void setup() throws Exception { + this.webSocketHandler = mock(WebSocketHandler.class); + } + + + @Test + public void connectReceiveAndClose() throws Exception { + String body = "o\n" + "a[\"foo\"]\n" + "c[3000,\"Go away!\"]"; + ClientHttpResponse response = response(HttpStatus.OK, body); + connect(response); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).handleMessage(any(), eq(new TextMessage("foo"))); + verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!"))); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void connectReceiveAndCloseWithPrelude() throws Exception { + StringBuilder sb = new StringBuilder(2048); + for (int i=0; i < 2048; i++) { + sb.append('h'); + } + String body = sb.toString() + "\n" + "o\n" + "a[\"foo\"]\n" + "c[3000,\"Go away!\"]"; + ClientHttpResponse response = response(HttpStatus.OK, body); + connect(response); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).handleMessage(any(), eq(new TextMessage("foo"))); + verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!"))); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void connectReceiveAndCloseWithStompFrame() throws Exception { + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND); + accessor.setDestination("/destination"); + MessageHeaders headers = accessor.getMessageHeaders(); + Message message = MessageBuilder.createMessage("body".getBytes(Charset.forName("UTF-8")), headers); + byte[] bytes = new StompEncoder().encode(message); + TextMessage textMessage = new TextMessage(bytes); + SockJsFrame frame = SockJsFrame.messageFrame(new Jackson2SockJsMessageCodec(), textMessage.getPayload()); + + String body = "o\n" + frame.getContent() + "\n" + "c[3000,\"Go away!\"]"; + ClientHttpResponse response = response(HttpStatus.OK, body); + connect(response); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).handleMessage(any(), eq(textMessage)); + verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!"))); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void connectFailure() throws Exception { + final HttpServerErrorException expected = new HttpServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR); + RestOperations restTemplate = mock(RestOperations.class); + when(restTemplate.execute(any(), eq(HttpMethod.POST), any(), any())).thenThrow(expected); + + final CountDownLatch latch = new CountDownLatch(1); + connect(restTemplate).addCallback( + new ListenableFutureCallback() { + @Override + public void onSuccess(WebSocketSession result) { + } + @Override + public void onFailure(Throwable actual) { + if (actual == expected) { + latch.countDown(); + } + } + } + ); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void errorResponseStatus() throws Exception { + connect(response(HttpStatus.OK, "o\n"), response(HttpStatus.INTERNAL_SERVER_ERROR, "Oops")); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).handleTransportError(any(), any()); + verify(this.webSocketHandler).afterConnectionClosed(any(), any()); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void responseClosedAfterDisconnected() throws Exception { + String body = "o\n" + "c[3000,\"Go away!\"]\n" + "a[\"foo\"]\n"; + ClientHttpResponse response = response(HttpStatus.OK, body); + connect(response); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).afterConnectionClosed(any(), any()); + verifyNoMoreInteractions(this.webSocketHandler); + verify(response).close(); + } + + private ListenableFuture connect(ClientHttpResponse... responses) throws Exception { + return connect(new TestRestTemplate(responses)); + } + + private ListenableFuture connect(RestOperations restTemplate, ClientHttpResponse... responses) + throws Exception { + + RestTemplateXhrTransport transport = new RestTemplateXhrTransport(restTemplate); + transport.setTaskExecutor(new SyncTaskExecutor()); + + SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); + HttpHeaders headers = new HttpHeaders(); + headers.add("h-foo", "h-bar"); + TransportRequest request = new DefaultTransportRequest(urlInfo, headers, transport, TransportType.XHR, CODEC); + + return transport.connect(request, this.webSocketHandler); + } + + private ClientHttpResponse response(HttpStatus status, String body) throws IOException { + ClientHttpResponse response = mock(ClientHttpResponse.class); + InputStream inputStream = getInputStream(body); + when(response.getStatusCode()).thenReturn(status); + when(response.getBody()).thenReturn(inputStream); + return response; + } + + private InputStream getInputStream(String content) { + byte[] bytes = content.getBytes(Charset.forName("UTF-8")); + return new ByteArrayInputStream(bytes); + } + + + + private static class TestRestTemplate extends RestTemplate { + + private Queue responses = new LinkedBlockingDeque<>(); + + + private TestRestTemplate(ClientHttpResponse... responses) { + this.responses.addAll(Arrays.asList(responses)); + } + + @Override + public T execute(URI url, HttpMethod method, RequestCallback callback, ResponseExtractor extractor) throws RestClientException { + try { + extractor.extractData(this.responses.remove()); + } + catch (Throwable t) { + throw new RestClientException("Failed to invoke extractor", t); + } + return null; + } + } + + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java new file mode 100644 index 0000000000..b304257803 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpStatus; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.client.TestTransport.XhrTestTransport; + +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link org.springframework.web.socket.sockjs.client.SockJsClient}. + * + * @author Rossen Stoyanchev + */ +public class SockJsClientTests { + + private static final String URL = "http://example.com"; + + private static final WebSocketHandler handler = mock(WebSocketHandler.class); + + + private SockJsClient sockJsClient; + + private InfoReceiver infoReceiver; + + private TestTransport webSocketTransport; + + private XhrTestTransport xhrTransport; + + private ListenableFutureCallback connectCallback; + + + @Before + @SuppressWarnings("unchecked") + public void setup() { + this.infoReceiver = mock(InfoReceiver.class); + this.webSocketTransport = new TestTransport("WebSocketTestTransport"); + this.xhrTransport = new XhrTestTransport("XhrTestTransport"); + + List transports = new ArrayList<>(); + transports.add(this.webSocketTransport); + transports.add(this.xhrTransport); + this.sockJsClient = new SockJsClient(transports); + this.sockJsClient.setInfoReceiver(this.infoReceiver); + + this.connectCallback = mock(ListenableFutureCallback.class); + } + + @Test + public void connectWebSocket() throws Exception { + setupInfoRequest(true); + this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback); + assertTrue(this.webSocketTransport.invoked()); + WebSocketSession session = mock(WebSocketSession.class); + this.webSocketTransport.getConnectCallback().onSuccess(session); + verify(this.connectCallback).onSuccess(session); + verifyNoMoreInteractions(this.connectCallback); + } + + @Test + public void connectWebSocketDisabled() throws URISyntaxException { + setupInfoRequest(false); + this.sockJsClient.doHandshake(handler, URL); + assertFalse(this.webSocketTransport.invoked()); + assertTrue(this.xhrTransport.invoked()); + assertTrue(this.xhrTransport.getRequest().getTransportUrl().toString().endsWith("xhr_streaming")); + } + + @Test + public void connectXhrStreamingDisabled() throws Exception { + setupInfoRequest(false); + this.xhrTransport.setStreamingDisabled(true); + this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback); + assertFalse(this.webSocketTransport.invoked()); + assertTrue(this.xhrTransport.invoked()); + assertTrue(this.xhrTransport.getRequest().getTransportUrl().toString().endsWith("xhr")); + } + + @Test + public void connectSockJsInfo() throws Exception { + setupInfoRequest(true); + this.sockJsClient.doHandshake(handler, URL); + verify(this.infoReceiver, times(1)).executeInfoRequest(any()); + } + + @Test + public void connectSockJsInfoCached() throws Exception { + setupInfoRequest(true); + this.sockJsClient.doHandshake(handler, URL); + this.sockJsClient.doHandshake(handler, URL); + this.sockJsClient.doHandshake(handler, URL); + verify(this.infoReceiver, times(1)).executeInfoRequest(any()); + } + + @Test + @SuppressWarnings("unchecked") + public void connectInfoRequestFailure() throws URISyntaxException { + HttpServerErrorException exception = new HttpServerErrorException(HttpStatus.SERVICE_UNAVAILABLE); + when(this.infoReceiver.executeInfoRequest(any())).thenThrow(exception); + this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback); + verify(this.connectCallback).onFailure(exception); + assertFalse(this.webSocketTransport.invoked()); + assertFalse(this.xhrTransport.invoked()); + } + + private void setupInfoRequest(boolean webSocketEnabled) { + when(this.infoReceiver.executeInfoRequest(any())).thenReturn("{\"entropy\":123," + + "\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":" + webSocketEnabled + "}"); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfoTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfoTests.java new file mode 100644 index 0000000000..462e27b783 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfoTests.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Assert; +import org.junit.Test; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.URI; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Unit tests for {@code SockJsUrlInfo}. + * @author Rossen Stoyanchev + */ +public class SockJsUrlInfoTests { + + + @Test + public void serverId() throws Exception { + SockJsUrlInfo info = new SockJsUrlInfo(new URI("http://example.com")); + int serverId = Integer.valueOf(info.getServerId()); + assertTrue("Invalid serverId: " + serverId, serverId > 0 && serverId < 1000); + } + + @Test + public void sessionId() throws Exception { + SockJsUrlInfo info = new SockJsUrlInfo(new URI("http://example.com")); + assertEquals("Invalid sessionId: " + info.getSessionId(), 32, info.getSessionId().length()); + } + + @Test + public void infoUrl() throws Exception { + testInfoUrl("http", "http"); + testInfoUrl("http", "http"); + testInfoUrl("https", "https"); + testInfoUrl("https", "https"); + testInfoUrl("ws", "http"); + testInfoUrl("ws", "http"); + testInfoUrl("wss", "https"); + testInfoUrl("wss", "https"); + } + + private void testInfoUrl(String scheme, String expectedScheme) throws Exception { + SockJsUrlInfo info = new SockJsUrlInfo(new URI(scheme + "://example.com")); + Assert.assertThat(info.getInfoUrl(), is(equalTo(new URI(expectedScheme + "://example.com/info")))); + } + + @Test + public void transportUrl() throws Exception { + testTransportUrl("http", "http", TransportType.XHR_STREAMING); + testTransportUrl("http", "ws", TransportType.WEBSOCKET); + testTransportUrl("https", "https", TransportType.XHR_STREAMING); + testTransportUrl("https", "wss", TransportType.WEBSOCKET); + testTransportUrl("ws", "http", TransportType.XHR_STREAMING); + testTransportUrl("ws", "ws", TransportType.WEBSOCKET); + testTransportUrl("wss", "https", TransportType.XHR_STREAMING); + testTransportUrl("wss", "wss", TransportType.WEBSOCKET); + } + + private void testTransportUrl(String scheme, String expectedScheme, TransportType transportType) throws Exception { + SockJsUrlInfo info = new SockJsUrlInfo(new URI(scheme + "://example.com")); + String serverId = info.getServerId(); + String sessionId = info.getSessionId(); + String transport = transportType.toString().toLowerCase(); + URI expected = new URI(expectedScheme + "://example.com/" + serverId + "/" + sessionId + "/" + transport); + assertThat(info.getTransportUrl(transportType), equalTo(expected)); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java new file mode 100644 index 0000000000..f54083b2ce --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.mockito.ArgumentCaptor; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; + +import java.net.URI; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Test SockJS Transport. + * + * @author Rossen Stoyanchev + */ +class TestTransport implements Transport { + + private final String name; + + private TransportRequest request; + + private ListenableFuture future; + + + public TestTransport(String name) { + this.name = name; + } + + public TransportRequest getRequest() { + return this.request; + } + + public boolean invoked() { + return this.future != null; + } + + @SuppressWarnings("unchecked") + public ListenableFutureCallback getConnectCallback() { + ArgumentCaptor captor = ArgumentCaptor.forClass(ListenableFutureCallback.class); + verify(this.future).addCallback(captor.capture()); + return captor.getValue(); + } + + @SuppressWarnings("unchecked") + @Override + public ListenableFuture connect(TransportRequest request, WebSocketHandler handler) { + this.request = request; + this.future = mock(ListenableFuture.class); + return this.future; + } + + @Override + public String toString() { + return "TestTransport[" + name + "]"; + } + + + static class XhrTestTransport extends TestTransport implements XhrTransport { + + private boolean streamingDisabled; + + + XhrTestTransport(String name) { + super(name); + } + + public void setStreamingDisabled(boolean streamingDisabled) { + this.streamingDisabled = streamingDisabled; + } + + @Override + public boolean isXhrStreamingDisabled() { + return this.streamingDisabled; + } + + @Override + public void executeSendRequest(URI transportUrl, TextMessage message) { + } + + @Override + public String executeInfoRequest(URI infoUrl) { + return null; + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java new file mode 100644 index 0000000000..d13bc207a1 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java @@ -0,0 +1,155 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; + +import java.net.URI; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.notNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +/** + * Unit tests for + * {@link org.springframework.web.socket.sockjs.client.AbstractXhrTransport}. + * + * @author Rossen Stoyanchev + */ +public class XhrTransportTests { + + @Test + public void infoResponse() throws Exception { + TestXhrTransport transport = new TestXhrTransport(); + transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK); + assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"))); + } + + @Test(expected = HttpServerErrorException.class) + public void infoResponseError() throws Exception { + TestXhrTransport transport = new TestXhrTransport(); + transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.BAD_REQUEST); + assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"))); + } + + @Test + public void sendMessage() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("foo", "bar"); + TestXhrTransport transport = new TestXhrTransport(); + transport.setRequestHeaders(requestHeaders); + transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.NO_CONTENT); + URI url = new URI("http://example.com"); + transport.executeSendRequest(url, new TextMessage("payload")); + assertEquals(2, transport.actualSendRequestHeaders.size()); + assertEquals("bar", transport.actualSendRequestHeaders.getFirst("foo")); + assertEquals(MediaType.APPLICATION_JSON, transport.actualSendRequestHeaders.getContentType()); + } + + @Test(expected = HttpServerErrorException.class) + public void sendMessageError() throws Exception { + TestXhrTransport transport = new TestXhrTransport(); + transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.BAD_REQUEST); + URI url = new URI("http://example.com"); + transport.executeSendRequest(url, new TextMessage("payload")); + } + + @Test + public void connect() throws Exception { + HttpHeaders handshakeHeaders = new HttpHeaders(); + handshakeHeaders.setOrigin("foo"); + + TransportRequest request = mock(TransportRequest.class); + when(request.getSockJsUrlInfo()).thenReturn(new SockJsUrlInfo(new URI("http://example.com"))); + when(request.getHandshakeHeaders()).thenReturn(handshakeHeaders); + + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("foo", "bar"); + + TestXhrTransport transport = new TestXhrTransport(); + transport.setRequestHeaders(requestHeaders); + + WebSocketHandler handler = mock(WebSocketHandler.class); + transport.connect(request, handler); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); + verify(request).getSockJsUrlInfo(); + verify(request).addTimeoutTask(captor.capture()); + verify(request).getTransportUrl(); + verify(request).getHandshakeHeaders(); + verifyNoMoreInteractions(request); + + assertEquals(2, transport.actualHandshakeHeaders.size()); + assertEquals("foo", transport.actualHandshakeHeaders.getOrigin()); + assertEquals("bar", transport.actualHandshakeHeaders.getFirst("foo")); + + assertFalse(transport.actualSession.isDisconnected()); + captor.getValue().run(); + assertTrue(transport.actualSession.isDisconnected()); + } + + + private static class TestXhrTransport extends AbstractXhrTransport { + + private ResponseEntity infoResponseToReturn; + + private ResponseEntity sendMessageResponseToReturn; + + private HttpHeaders actualSendRequestHeaders; + + private HttpHeaders actualHandshakeHeaders; + + private XhrClientSockJsSession actualSession; + + + @Override + protected ResponseEntity executeInfoRequestInternal(URI infoUrl) { + return this.infoResponseToReturn; + } + + @Override + protected ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) { + this.actualSendRequestHeaders = headers; + return this.sendMessageResponseToReturn; + } + + @Override + protected void connectInternal(TransportRequest request, WebSocketHandler handler, URI receiveUrl, + HttpHeaders handshakeHeaders, XhrClientSockJsSession session, + SettableListenableFuture connectFuture) { + + this.actualHandshakeHeaders = handshakeHeaders; + this.actualSession = session; + } + } + +} diff --git a/spring-websocket/src/test/resources/log4j.properties b/spring-websocket/src/test/resources/log4j.properties index 8db186fb4e..0b8d9ec5f6 100644 --- a/spring-websocket/src/test/resources/log4j.properties +++ b/spring-websocket/src/test/resources/log4j.properties @@ -1,9 +1,9 @@ log4j.appender.console=org.apache.log4j.ConsoleAppender log4j.appender.console.layout=org.apache.log4j.PatternLayout -log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} [%c] - %m%n +log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} [%c][%t] - %m%n log4j.rootCategory=WARN, console log4j.logger.org.springframework.web=DEBUG -log4j.logger.org.springframework.web.socket=DEBUG +log4j.logger.org.springframework.web.socket=TRACE log4j.logger.org.springframework.messaging=DEBUG From fcf6ae83285890f2e8e7de77dcc6e633129978e8 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Mon, 23 Jun 2014 03:14:00 -0400 Subject: [PATCH 5/7] Fix concurrency issues in SockJS session impls This change ensures the server "WebSocketHandler" is notified of the opening of a session before writing the open frame to the remote handler. Any messages sent by the server "WebSocketHandler" while getting notified of the opening get cached and flushed after the open frame has been written. This change introduces locking in AbtractHttpSockJsSession to guard access to the HTTP response. The goal is to prevent contention between client requests to receive messages (i.e. long polling) and the application trying to write. Issue: SPR-11916 --- .../session/AbstractHttpSockJsSession.java | 229 +++++++++--------- .../session/AbstractSockJsSession.java | 9 +- .../session/PollingSockJsSession.java | 23 +- .../session/StreamingSockJsSession.java | 32 +-- .../session/WebSocketServerSockJsSession.java | 53 +++- .../client/JettySockJsIntegrationTests.java | 3 - .../session/HttpSockJsSessionTests.java | 8 +- .../session/TestHttpSockJsSession.java | 5 + .../WebSocketServerSockJsSessionTests.java | 22 +- 9 files changed, 202 insertions(+), 182 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java index bd3dccd613..c2f8de4718 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java @@ -30,7 +30,6 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.server.ServerHttpAsyncRequestControl; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; -import org.springframework.util.Assert; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; @@ -41,7 +40,7 @@ import org.springframework.web.socket.sockjs.frame.SockJsFrameFormat; import org.springframework.web.socket.sockjs.transport.SockJsServiceConfig; /** - * An abstract base class for use with HTTP transport based SockJS sessions. + * An abstract base class for use with HTTP transport SockJS sessions. * * @author Rossen Stoyanchev * @since 4.0 @@ -64,9 +63,12 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { private volatile ServerHttpResponse response; + private volatile SockJsFrameFormat frameFormat; + + private volatile ServerHttpAsyncRequestControl asyncRequestControl; - private volatile SockJsFrameFormat frameFormat; + private final Object responseLock = new Object(); private volatile boolean requestInitialized; @@ -124,18 +126,10 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { return this.acceptedProtocol; } - /** - * Return response for the current request, or {@code null} if between requests. - */ - protected ServerHttpResponse getResponse() { - return this.response; - } - /** * Return the SockJS buffer for messages stored transparently between polling * requests. If the polling request takes longer than 5 seconds, the session - * will be closed. - * + * is closed. * @see org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService */ protected Queue getMessageCache() { @@ -173,123 +167,104 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { return Collections.emptyList(); } + /** + * Whether this HTTP transport streams message frames vs closing the response + * after each frame written (long polling). + */ + protected abstract boolean isStreaming(); + /** - * Handle the first HTTP request, i.e. the one that starts a SockJS session. - * Write a prelude to the response (if needed), send the SockJS "open" frame - * to indicate to the client the session is opened, and invoke the - * delegate WebSocketHandler to provide it with the newly opened session. - *

- * The "xhr" and "jsonp" (polling-based) transports completes the initial request - * as soon as the open frame is sent. Following that the client should start a - * successive polling request within the same SockJS session. - *

- * The "xhr_streaming", "eventsource", and "htmlfile" transports are streaming - * based and will leave the initial request open in order to stream one or - * more messages. However, even streaming based transports eventually recycle - * the long running request, after a certain number of bytes have been streamed - * (128K by default), and allow the client to start a successive request within - * the same SockJS session. + * Handle the first request for receiving messages on a SockJS HTTP transport + * based session. + * + *

Long polling-based transports (e.g. "xhr", "jsonp") complete the request + * after writing the open frame. Streaming-based transports ("xhr_streaming", + * "eventsource", and "htmlfile") leave the response open longer for further + * streaming of message frames but will also close it eventually after some + * amount of data has been sent. * * @param request the current request * @param response the current response * @param frameFormat the transport-specific SocksJS frame format to use - * - * @see #handleSuccessiveRequest(org.springframework.http.server.ServerHttpRequest, org.springframework.http.server.ServerHttpResponse, org.springframework.web.socket.sockjs.frame.SockJsFrameFormat) */ public void handleInitialRequest(ServerHttpRequest request, ServerHttpResponse response, SockJsFrameFormat frameFormat) throws SockJsException { - initRequest(request, response, frameFormat); - this.uri = request.getURI(); this.handshakeHeaders = request.getHeaders(); this.principal = request.getPrincipal(); this.localAddress = request.getLocalAddress(); this.remoteAddress = request.getRemoteAddress(); + this.response = response; + this.frameFormat = frameFormat; + this.asyncRequestControl = request.getAsyncRequestControl(response); + try { + // Let "our" handler know before sending the open frame to the remote handler + delegateConnectionEstablished(); writePrelude(request, response); writeFrame(SockJsFrame.openFrame()); + if (isStreaming() && !isClosed()) { + startAsyncRequest(); + } } catch (Throwable ex) { tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); - throw new SockJsTransportFailureException("Failed to send \"open\" frame", getId(), ex); - } - - try { - this.requestInitialized = true; - delegateConnectionEstablished(); - } - catch (Throwable ex) { - throw new SockJsException("Unhandled exception from WebSocketHandler", getId(), ex); + throw new SockJsTransportFailureException("Failed to open session", getId(), ex); } } - private void initRequest(ServerHttpRequest request, ServerHttpResponse response, - SockJsFrameFormat frameFormat) { - - Assert.notNull(request, "Request must not be null"); - Assert.notNull(response, "Response must not be null"); - Assert.notNull(frameFormat, "SockJsFrameFormat must not be null"); - - this.response = response; - this.frameFormat = frameFormat; - this.asyncRequestControl = request.getAsyncRequestControl(response); + protected void writePrelude(ServerHttpRequest request, ServerHttpResponse response) throws IOException { } - protected void writePrelude(ServerHttpRequest request, ServerHttpResponse response) throws IOException { + private void startAsyncRequest() { + this.asyncRequestControl.start(-1); + if (this.messageCache.size() > 0) { + flushCache(); + } + else { + scheduleHeartbeat(); + } + this.requestInitialized = true; } /** - * Handle all HTTP requests part of the same SockJS session except for the very - * first, initial request. Write a prelude (if needed) and keep the request - * open and ready to send a message from the server to the client. - *

- * The "xhr" and "jsonp" (polling-based) transports completes the request when - * the next message is sent, which could be an array of messages cached during - * the time between successive requests, or it could be a heartbeat message - * sent if no other messages were sent (by default within 25 seconds). - *

- * The "xhr_streaming", "eventsource", and "htmlfile" transports are streaming - * based and will leave the request open longer in order to stream messages over - * a period of time. However, even streaming based transports eventually recycle - * the long running request, after a certain number of bytes have been streamed - * (128K by default), and allow the client to start a successive request within - * the same SockJS session. + * Handle all requests, except the first one, to receive messages on a SockJS + * HTTP transport based session. + * + *

Long polling-based transports (e.g. "xhr", "jsonp") complete the request + * after writing any buffered message frames (or the next one). Streaming-based + * transports ("xhr_streaming", "eventsource", and "htmlfile") leave the + * response open longer for further streaming of message frames but will also + * close it eventually after some amount of data has been sent. * * @param request the current request * @param response the current response * @param frameFormat the transport-specific SocksJS frame format to use - * - * @see #handleInitialRequest(org.springframework.http.server.ServerHttpRequest, org.springframework.http.server.ServerHttpResponse, org.springframework.web.socket.sockjs.frame.SockJsFrameFormat) */ public void handleSuccessiveRequest(ServerHttpRequest request, ServerHttpResponse response, SockJsFrameFormat frameFormat) throws SockJsException { - initRequest(request, response, frameFormat); - try { - writePrelude(request, response); - } - catch (Throwable ex) { - tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); - throw new SockJsTransportFailureException("Failed to send \"open\" frame", getId(), ex); + synchronized (this.responseLock) { + try { + if (isClosed()) { + response.getBody().write(SockJsFrame.closeFrameGoAway().getContentBytes()); + } + this.response = response; + this.frameFormat = frameFormat; + this.asyncRequestControl = request.getAsyncRequestControl(response); + writePrelude(request, response); + startAsyncRequest(); + } + catch (Throwable ex) { + tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); + throw new SockJsTransportFailureException("Failed to handle SockJS receive request", getId(), ex); + } } - startAsyncRequest(); } - protected void startAsyncRequest() throws SockJsException { - try { - this.asyncRequestControl.start(-1); - this.requestInitialized = true; - scheduleHeartbeat(); - tryFlushCache(); - } - catch (Throwable ex) { - tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); - throw new SockJsTransportFailureException("Failed to flush messages", getId(), ex); - } - } @Override protected final void sendMessageInternal(String message) throws SockJsTransportFailureException { @@ -297,27 +272,35 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { tryFlushCache(); } - private void tryFlushCache() throws SockJsTransportFailureException { - if (this.messageCache.isEmpty()) { - logger.trace("Nothing to flush"); - return; - } - if (logger.isTraceEnabled()) { - logger.trace(this.messageCache.size() + " message(s) to flush"); - } - if (isActive() && this.requestInitialized) { - logger.trace("Flushing messages"); - flushCache(); - } - else { + private boolean tryFlushCache() throws SockJsTransportFailureException { + synchronized (this.responseLock) { + if (this.messageCache.isEmpty()) { + logger.trace("Nothing to flush in session=" + this.getId()); + return false; + } if (logger.isTraceEnabled()) { - logger.trace("Not ready to flush"); + logger.trace(this.messageCache.size() + " message(s) to flush in session " + this.getId()); + } + if (isActive() && this.requestInitialized) { + if (logger.isTraceEnabled()) { + logger.trace("Session is active, ready to flush."); + } + cancelHeartbeat(); + flushCache(); + return true; + } + else { + if (logger.isTraceEnabled()) { + logger.trace("Session is not active, not ready to flush."); + } + return false; } } } /** - * Only called if the connection is currently active + * Called when the connection is active and ready to write to the response. + * Sub-classes should implement but never call this method directly. */ protected abstract void flushCache() throws SockJsTransportFailureException; @@ -327,25 +310,27 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { } protected void resetRequest() { + synchronized (this.responseLock) { - this.requestInitialized = false; - updateLastActiveTime(); - - if (isActive()) { ServerHttpAsyncRequestControl control = this.asyncRequestControl; - if (control.isStarted()) { - try { - logger.debug("Completing asynchronous request"); - control.complete(); - } - catch (Throwable ex) { - logger.error("Failed to complete request: " + ex.getMessage()); + this.asyncRequestControl = null; + this.requestInitialized = false; + this.response = null; + + updateLastActiveTime(); + + if (control != null && !control.isCompleted()) { + if (control.isStarted()) { + try { + logger.debug("Completing asynchronous request"); + control.complete(); + } + catch (Throwable ex) { + logger.error("Failed to complete request: " + ex.getMessage()); + } } } } - - this.response = null; - this.asyncRequestControl = null; } @Override @@ -353,9 +338,15 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { if (isActive()) { String formattedFrame = this.frameFormat.format(frame); if (logger.isTraceEnabled()) { - logger.trace("Writing " + formattedFrame); + logger.trace("Writing to HTTP response: " + formattedFrame); + } + this.response.getBody().write(formattedFrame.getBytes(SockJsFrame.CHARSET)); + if (isStreaming()) { + this.response.flush(); + } + else { + resetRequest(); } - getResponse().getBody().write(formattedFrame.getBytes(SockJsFrame.CHARSET)); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java index 215a69f6c3..0a0ec5770a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java @@ -40,6 +40,7 @@ import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.sockjs.SockJsMessageDeliveryException; import org.springframework.web.socket.sockjs.SockJsTransportFailureException; import org.springframework.web.socket.sockjs.frame.SockJsFrame; +import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; import org.springframework.web.socket.sockjs.transport.SockJsServiceConfig; import org.springframework.web.socket.sockjs.transport.SockJsSession; @@ -142,6 +143,10 @@ public abstract class AbstractSockJsSession implements SockJsSession { return this.id; } + protected SockJsMessageCodec getMessageCodec() { + return this.config.getMessageCodec(); + } + public SockJsServiceConfig getSockJsServiceConfig() { return this.config; } @@ -420,7 +425,9 @@ public abstract class AbstractSockJsSession implements SockJsSession { @Override public String toString() { - return "SockJS session id=" + this.id; + long currentTime = System.currentTimeMillis(); + return "SockJsSession[id=" + this.id + ", state=" + this.state + ", sinceCreated=" + + (currentTime - this.timeCreated) + ", sinceLastActive=" + (currentTime - this.timeLastActive) + "]"; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java index 52e27829a6..fa2f721722 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java @@ -16,6 +16,8 @@ package org.springframework.web.socket.sockjs.transport.session; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.Queue; @@ -33,6 +35,7 @@ import org.springframework.web.socket.sockjs.transport.SockJsServiceConfig; */ public class PollingSockJsSession extends AbstractHttpSockJsSession { + public PollingSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler wsHandler, Map attributes) { @@ -41,22 +44,20 @@ public class PollingSockJsSession extends AbstractHttpSockJsSession { @Override - protected void flushCache() throws SockJsTransportFailureException { - cancelHeartbeat(); - Queue messageCache = getMessageCache(); - String[] messages = messageCache.toArray(new String[messageCache.size()]); - messageCache.clear(); + protected boolean isStreaming() { + return false; + } + @Override + protected void flushCache() throws SockJsTransportFailureException { + String[] messages = new String[getMessageCache().size()]; + for (int i = 0; i < messages.length; i++) { + messages[i] = getMessageCache().poll(); + } SockJsMessageCodec messageCodec = getSockJsServiceConfig().getMessageCodec(); SockJsFrame frame = SockJsFrame.messageFrame(messageCodec, messages); writeFrame(frame); } - @Override - protected void writeFrame(SockJsFrame frame) throws SockJsTransportFailureException { - super.writeFrame(frame); - resetRequest(); - } - } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java index 70facc8776..1e2d301dbe 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package org.springframework.web.socket.sockjs.transport.session; -import java.io.IOException; import java.util.Map; import org.springframework.http.server.ServerHttpRequest; @@ -48,21 +47,12 @@ public class StreamingSockJsSession extends AbstractHttpSockJsSession { @Override - public void handleInitialRequest(ServerHttpRequest request, ServerHttpResponse response, - SockJsFrameFormat frameFormat) throws SockJsException { - - super.handleInitialRequest(request, response, frameFormat); - - // the WebSocketHandler delegate may have closed the session - if (!isClosed()) { - super.startAsyncRequest(); - } + protected boolean isStreaming() { + return true; } @Override protected void flushCache() throws SockJsTransportFailureException { - cancelHeartbeat(); - do { String message = getMessageCache().poll(); SockJsMessageCodec messageCodec = getSockJsServiceConfig().getMessageCodec(); @@ -79,26 +69,12 @@ public class StreamingSockJsSession extends AbstractHttpSockJsSession { logger.trace("Streamed bytes limit reached. Recycling current request"); } resetRequest(); + this.byteCount = 0; break; } } while (!getMessageCache().isEmpty()); - scheduleHeartbeat(); } - @Override - protected void resetRequest() { - super.resetRequest(); - this.byteCount = 0; - } - - @Override - protected void writeFrameInternal(SockJsFrame frame) throws IOException { - if (isActive()) { - super.writeFrameInternal(frame); - getResponse().flush(); - } - } - } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java index 51e2df3f22..d250d9e34c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java @@ -22,6 +22,10 @@ import java.net.URI; import java.security.Principal; import java.util.List; import java.util.Map; +import java.util.Queue; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import org.springframework.http.HttpHeaders; import org.springframework.util.Assert; @@ -34,7 +38,6 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.NativeWebSocketSession; import org.springframework.web.socket.sockjs.SockJsTransportFailureException; import org.springframework.web.socket.sockjs.frame.SockJsFrame; -import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; import org.springframework.web.socket.sockjs.transport.SockJsServiceConfig; /** @@ -47,6 +50,12 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession implemen private WebSocketSession webSocketSession; + private volatile boolean openFrameSent; + + private final Queue initSessionCache = new LinkedBlockingDeque(); + + private final Lock initSessionLock = new ReentrantLock(); + public WebSocketServerSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler handler, Map attributes) { @@ -143,15 +152,23 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession implemen public void initializeDelegateSession(WebSocketSession session) { - this.webSocketSession = session; - try { - TextMessage message = new TextMessage(SockJsFrame.openFrame().getContent()); - this.webSocketSession.sendMessage(message); - scheduleHeartbeat(); - delegateConnectionEstablished(); - } - catch (Exception ex) { - tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); + synchronized (this.initSessionLock) { + this.webSocketSession = session; + try { + // Let "our" handler know before sending the open frame to the remote handler + delegateConnectionEstablished(); + this.webSocketSession.sendMessage(new TextMessage(SockJsFrame.openFrame().getContent())); + + // Flush any messages cached in the mean time + while (!this.initSessionCache.isEmpty()) { + writeFrame(SockJsFrame.messageFrame(getMessageCodec(), this.initSessionCache.poll())); + } + scheduleHeartbeat(); + this.openFrameSent = true; + } + catch (Exception ex) { + tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); + } } } @@ -180,10 +197,20 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession implemen @Override public void sendMessageInternal(String message) throws SockJsTransportFailureException { + + // Open frame not sent yet? + // If in the session initialization thread, then cache, otherwise wait. + + if (!this.openFrameSent) { + synchronized (this.initSessionLock) { + if (!this.openFrameSent) { + this.initSessionCache.add(message); + return; + } + } + } cancelHeartbeat(); - SockJsMessageCodec messageCodec = getSockJsServiceConfig().getMessageCodec(); - SockJsFrame frame = SockJsFrame.messageFrame(messageCodec, message); - writeFrame(frame); + writeFrame(SockJsFrame.messageFrame(getMessageCodec(), message)); scheduleHeartbeat(); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/JettySockJsIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/JettySockJsIntegrationTests.java index 89120794ec..31e1868b30 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/JettySockJsIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/JettySockJsIntegrationTests.java @@ -27,9 +27,6 @@ import org.springframework.web.socket.client.jetty.JettyWebSocketClient; import org.springframework.web.socket.server.RequestUpgradeStrategy; import org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy; -import java.util.ArrayList; -import java.util.List; - /** * SockJS integration tests using Jetty for client and server. * diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/HttpSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/HttpSockJsSessionTests.java index d0c96bdf4c..b2ec62daea 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/HttpSockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/HttpSockJsSessionTests.java @@ -94,10 +94,8 @@ public class HttpSockJsSessionTests extends AbstractSockJsSessionTests expected = Arrays.asList(new TextMessage("o"), new TextMessage("a[\"go go\"]")); + assertEquals(expected, this.webSocketSession.getSentMessages()); + } + @Test public void handleMessageEmptyPayload() throws Exception { this.session.handleMessage(new TextMessage(""), this.webSocketSession); From 5d2e6f6d4ca9c79ad145c6dad28171b521f50221 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Mon, 23 Jun 2014 04:12:38 -0400 Subject: [PATCH 6/7] Prevent unbounded retention of cancelled SockJS tasks This change sets the removeOnCancelPolicy on the SockJS ScheduledThreadPoolExecutor to true. This ensures that cancelled tasks are removed immediately to avoid the "unbounded retention of cancelled tasks" that is mentioned in the Javadoc of ScheduledThreadPoolExecutor: "By default, such a cancelled task is not automatically removed from the work queue until its delay elapses. While this enables further inspection and monitoring, it may also cause unbounded retention of cancelled tasks. To avoid this, set setRemoveOnCancelPolicy to true, which causes tasks to be immediately removed from the work queue at time of cancellation." Issue: SPR-11918 --- .../config/WebSocketNamespaceUtils.java | 19 ++++++++++++++++++- .../WebSocketConfigurationSupport.java | 15 ++++++++++++++- ...cketMessageBrokerConfigurationSupport.java | 14 +++++++++++++- ...essageBrokerBeanDefinitionParserTests.java | 1 + ...essageBrokerConfigurationSupportTests.java | 6 ++++-- 5 files changed, 50 insertions(+), 5 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java index 18f1481d7c..0640c3de9c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java @@ -31,6 +31,11 @@ import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsSe import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.RejectedExecutionHandler; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.ThreadFactory; + /** * Provides utility methods for parsing common WebSocket XML namespace elements. * @@ -135,7 +140,7 @@ class WebSocketNamespaceUtils { ParserContext parserContext, Object source) { if (!parserContext.getRegistry().containsBeanDefinition(schedulerName)) { - RootBeanDefinition taskSchedulerDef = new RootBeanDefinition(ThreadPoolTaskScheduler.class); + RootBeanDefinition taskSchedulerDef = new RootBeanDefinition(SockJsThreadPoolTaskScheduler.class); taskSchedulerDef.setSource(source); taskSchedulerDef.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); taskSchedulerDef.getPropertyValues().add("poolSize", Runtime.getRuntime().availableProcessors()); @@ -161,4 +166,16 @@ class WebSocketNamespaceUtils { return beans; } + + @SuppressWarnings("serial") + private static class SockJsThreadPoolTaskScheduler extends ThreadPoolTaskScheduler { + + @Override + protected ExecutorService initializeExecutor(ThreadFactory factory, RejectedExecutionHandler handler) { + ExecutorService service = super.initializeExecutor(factory, handler); + ((ScheduledThreadPoolExecutor) service).setRemoveOnCancelPolicy(true); + return service; + } + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java index 6fcb7e49d9..f8769842b8 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java @@ -21,6 +21,11 @@ import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.handler.AbstractHandlerMapping; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.RejectedExecutionHandler; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.ThreadFactory; + /** * Configuration support for WebSocket request handling. * @@ -59,7 +64,15 @@ public class WebSocketConfigurationSupport { */ @Bean public ThreadPoolTaskScheduler defaultSockJsTaskScheduler() { - ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); + @SuppressWarnings("serial") + ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler() { + @Override + protected ExecutorService initializeExecutor(ThreadFactory factory, RejectedExecutionHandler handler) { + ExecutorService service = super.initializeExecutor(factory, handler); + ((ScheduledThreadPoolExecutor) service).setRemoveOnCancelPolicy(true); + return service; + } + }; scheduler.setThreadNamePrefix("SockJS-"); return scheduler; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java index 8633dc9539..a901da1082 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java @@ -17,6 +17,10 @@ package org.springframework.web.socket.config.annotation; import java.util.Collections; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.RejectedExecutionHandler; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.ThreadFactory; import org.springframework.beans.factory.config.CustomScopeConfigurer; import org.springframework.context.annotation.Bean; @@ -96,7 +100,15 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac */ @Bean public ThreadPoolTaskScheduler messageBrokerSockJsTaskScheduler() { - ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); + @SuppressWarnings("serial") + ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler() { + @Override + protected ExecutorService initializeExecutor(ThreadFactory factory, RejectedExecutionHandler handler) { + ExecutorService service = super.initializeExecutor(factory, handler); + ((ScheduledThreadPoolExecutor) service).setRemoveOnCancelPolicy(true); + return service; + } + }; scheduler.setPoolSize(Runtime.getRuntime().availableProcessors()); scheduler.setThreadNamePrefix("MessageBrokerSockJS-"); return scheduler; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index 67556f61d4..dfc73941cd 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -140,6 +140,7 @@ public class MessageBrokerBeanDefinitionParserTests { ThreadPoolTaskScheduler scheduler = (ThreadPoolTaskScheduler) defaultSockJsService.getTaskScheduler(); assertEquals(Runtime.getRuntime().availableProcessors(), scheduler.getScheduledThreadPoolExecutor().getCorePoolSize()); + assertTrue(scheduler.getScheduledThreadPoolExecutor().getRemoveOnCancelPolicy()); UserSessionRegistry userSessionRegistry = this.appContext.getBean(UserSessionRegistry.class); assertNotNull(userSessionRegistry); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java index 16d64a0577..0f16354d22 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ScheduledThreadPoolExecutor; import org.junit.Before; import org.junit.Test; @@ -128,8 +129,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests { ThreadPoolTaskScheduler taskScheduler = this.config.getBean("messageBrokerSockJsTaskScheduler", ThreadPoolTaskScheduler.class); - assertEquals(Runtime.getRuntime().availableProcessors(), - taskScheduler.getScheduledThreadPoolExecutor().getCorePoolSize()); + ScheduledThreadPoolExecutor executor = taskScheduler.getScheduledThreadPoolExecutor(); + assertEquals(Runtime.getRuntime().availableProcessors(), executor.getCorePoolSize()); + assertTrue(executor.getRemoveOnCancelPolicy()); } From 1e342e4dbff7d5edef23356daa304e05b40d4147 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 25 Jun 2014 08:40:27 -0400 Subject: [PATCH 7/7] Use SimpleAsyncTaskExecutor in WebSocketClient impls JettyWebsocketClient and StandardWebSocketClient now use the SimpleAsyncTaskExecutor by default to ensure non-blocking connect. --- .../web/socket/client/jetty/JettyWebSocketClient.java | 8 +++++--- .../socket/client/standard/StandardWebSocketClient.java | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java index 3bc7a965ec..485c920d52 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java @@ -28,6 +28,7 @@ import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.WebSocketClient; import org.springframework.context.SmartLifecycle; import org.springframework.core.task.AsyncListenableTaskExecutor; +import org.springframework.core.task.SimpleAsyncTaskExecutor; import org.springframework.core.task.TaskExecutor; import org.springframework.http.HttpHeaders; import org.springframework.util.concurrent.ListenableFuture; @@ -59,7 +60,7 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma private final Object lifecycleMonitor = new Object(); - private AsyncListenableTaskExecutor taskExecutor; + private AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(); /** @@ -81,9 +82,10 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma /** * Set an {@link AsyncListenableTaskExecutor} to use when opening connections. - * - *

If this property is not configured, calls to any of the + * If this property is set to {@code null}, calls to any of the * {@code doHandshake} methods will block until the connection is established. + * + *

By default an instance of {@code SimpleAsyncTaskExecutor} is used. */ public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) { this.taskExecutor = taskExecutor; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java index 001a790542..2addd49ed2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java @@ -34,6 +34,7 @@ import javax.websocket.HandshakeResponse; import javax.websocket.WebSocketContainer; import org.springframework.core.task.AsyncListenableTaskExecutor; +import org.springframework.core.task.SimpleAsyncTaskExecutor; import org.springframework.core.task.TaskExecutor; import org.springframework.http.HttpHeaders; import org.springframework.lang.UsesJava7; @@ -59,7 +60,7 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { private final WebSocketContainer webSocketContainer; - private AsyncListenableTaskExecutor taskExecutor; + private AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(); /** @@ -86,9 +87,10 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { /** * Set an {@link AsyncListenableTaskExecutor} to use when opening connections. - * - *

If this property is not configured, calls to any of the + * If this property is set to {@code null}, calls to any of the * {@code doHandshake} methods will block until the connection is established. + * + *

By default an instance of {@code SimpleAsyncTaskExecutor} is used. */ public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) { this.taskExecutor = taskExecutor;