diff --git a/build.gradle b/build.gradle index 6579189eef..14c58a4df1 100644 --- a/build.gradle +++ b/build.gradle @@ -670,6 +670,7 @@ project("spring-websocket") { exclude group: "javax.servlet", module: "javax.servlet" } optional("org.eclipse.jetty.websocket:websocket-client:${jettyVersion}") + optional("org.eclipse.jetty:jetty-client:${jettyVersion}") optional("io.undertow:undertow-core:1.0.15.Final") optional("io.undertow:undertow-servlet:1.0.15.Final") { exclude group: "org.jboss.spec.javax.servlet", module: "jboss-servlet-api_3.1_spec" diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java b/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java index 364a9597b7..9733e7f2a5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/CloseStatus.java @@ -217,7 +217,7 @@ public final class CloseStatus { @Override public String toString() { - return "CloseStatus [code=" + this.code + ", reason=" + this.reason + "]"; + return "CloseStatus[code=" + this.code + ", reason=" + this.reason + "]"; } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java index ea08db729a..dbec02cd99 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/AbstractWebSocketClient.java @@ -73,10 +73,7 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { WebSocketHttpHeaders headers, URI uri) { Assert.notNull(webSocketHandler, "webSocketHandler must not be null"); - Assert.notNull(uri, "uri must not be null"); - - String scheme = uri.getScheme(); - Assert.isTrue(((scheme != null) && ("ws".equals(scheme) || "wss".equals(scheme))), "Invalid scheme: " + scheme); + assertUri(uri); if (logger.isDebugEnabled()) { logger.debug("Connecting to " + uri); @@ -101,6 +98,12 @@ public abstract class AbstractWebSocketClient implements WebSocketClient { Collections.emptyMap()); } + protected void assertUri(URI uri) { + Assert.notNull(uri, "uri must not be null"); + String scheme = uri.getScheme(); + Assert.isTrue(scheme != null && ("ws".equals(scheme) || "wss".equals(scheme)), "Invalid scheme: " + scheme); + } + /** * Perform the actual handshake to establish a connection to the server. * diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java index 3bc7a965ec..485c920d52 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/jetty/JettyWebSocketClient.java @@ -28,6 +28,7 @@ import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.WebSocketClient; import org.springframework.context.SmartLifecycle; import org.springframework.core.task.AsyncListenableTaskExecutor; +import org.springframework.core.task.SimpleAsyncTaskExecutor; import org.springframework.core.task.TaskExecutor; import org.springframework.http.HttpHeaders; import org.springframework.util.concurrent.ListenableFuture; @@ -59,7 +60,7 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma private final Object lifecycleMonitor = new Object(); - private AsyncListenableTaskExecutor taskExecutor; + private AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(); /** @@ -81,9 +82,10 @@ public class JettyWebSocketClient extends AbstractWebSocketClient implements Sma /** * Set an {@link AsyncListenableTaskExecutor} to use when opening connections. - * - *

If this property is not configured, calls to any of the + * If this property is set to {@code null}, calls to any of the * {@code doHandshake} methods will block until the connection is established. + * + *

By default an instance of {@code SimpleAsyncTaskExecutor} is used. */ public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) { this.taskExecutor = taskExecutor; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java index 001a790542..2addd49ed2 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java @@ -34,6 +34,7 @@ import javax.websocket.HandshakeResponse; import javax.websocket.WebSocketContainer; import org.springframework.core.task.AsyncListenableTaskExecutor; +import org.springframework.core.task.SimpleAsyncTaskExecutor; import org.springframework.core.task.TaskExecutor; import org.springframework.http.HttpHeaders; import org.springframework.lang.UsesJava7; @@ -59,7 +60,7 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { private final WebSocketContainer webSocketContainer; - private AsyncListenableTaskExecutor taskExecutor; + private AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(); /** @@ -86,9 +87,10 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { /** * Set an {@link AsyncListenableTaskExecutor} to use when opening connections. - * - *

If this property is not configured, calls to any of the + * If this property is set to {@code null}, calls to any of the * {@code doHandshake} methods will block until the connection is established. + * + *

By default an instance of {@code SimpleAsyncTaskExecutor} is used. */ public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) { this.taskExecutor = taskExecutor; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java index 18f1481d7c..0640c3de9c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java @@ -31,6 +31,11 @@ import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsSe import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.RejectedExecutionHandler; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.ThreadFactory; + /** * Provides utility methods for parsing common WebSocket XML namespace elements. * @@ -135,7 +140,7 @@ class WebSocketNamespaceUtils { ParserContext parserContext, Object source) { if (!parserContext.getRegistry().containsBeanDefinition(schedulerName)) { - RootBeanDefinition taskSchedulerDef = new RootBeanDefinition(ThreadPoolTaskScheduler.class); + RootBeanDefinition taskSchedulerDef = new RootBeanDefinition(SockJsThreadPoolTaskScheduler.class); taskSchedulerDef.setSource(source); taskSchedulerDef.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); taskSchedulerDef.getPropertyValues().add("poolSize", Runtime.getRuntime().availableProcessors()); @@ -161,4 +166,16 @@ class WebSocketNamespaceUtils { return beans; } + + @SuppressWarnings("serial") + private static class SockJsThreadPoolTaskScheduler extends ThreadPoolTaskScheduler { + + @Override + protected ExecutorService initializeExecutor(ThreadFactory factory, RejectedExecutionHandler handler) { + ExecutorService service = super.initializeExecutor(factory, handler); + ((ScheduledThreadPoolExecutor) service).setRemoveOnCancelPolicy(true); + return service; + } + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java index 6fcb7e49d9..f8769842b8 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java @@ -21,6 +21,11 @@ import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.handler.AbstractHandlerMapping; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.RejectedExecutionHandler; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.ThreadFactory; + /** * Configuration support for WebSocket request handling. * @@ -59,7 +64,15 @@ public class WebSocketConfigurationSupport { */ @Bean public ThreadPoolTaskScheduler defaultSockJsTaskScheduler() { - ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); + @SuppressWarnings("serial") + ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler() { + @Override + protected ExecutorService initializeExecutor(ThreadFactory factory, RejectedExecutionHandler handler) { + ExecutorService service = super.initializeExecutor(factory, handler); + ((ScheduledThreadPoolExecutor) service).setRemoveOnCancelPolicy(true); + return service; + } + }; scheduler.setThreadNamePrefix("SockJS-"); return scheduler; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java index 8633dc9539..a901da1082 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java @@ -17,6 +17,10 @@ package org.springframework.web.socket.config.annotation; import java.util.Collections; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.RejectedExecutionHandler; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.ThreadFactory; import org.springframework.beans.factory.config.CustomScopeConfigurer; import org.springframework.context.annotation.Bean; @@ -96,7 +100,15 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac */ @Bean public ThreadPoolTaskScheduler messageBrokerSockJsTaskScheduler() { - ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); + @SuppressWarnings("serial") + ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler() { + @Override + protected ExecutorService initializeExecutor(ThreadFactory factory, RejectedExecutionHandler handler) { + ExecutorService service = super.initializeExecutor(factory, handler); + ((ScheduledThreadPoolExecutor) service).setRemoveOnCancelPolicy(true); + return service; + } + }; scheduler.setPoolSize(Runtime.getRuntime().availableProcessors()); scheduler.setThreadNamePrefix("MessageBrokerSockJS-"); return scheduler; diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index c82ebf8928..417955a47d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -194,7 +194,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } } catch (Throwable ex) { - logger.error("Failed to parse WebSocket message to STOMP frame(s)", ex); + logger.error("Failed to parse WebSocket message to STOMP." + + "Sending STOMP ERROR to client, sessionId=" + session.getId(), ex); sendErrorMessage(session, ex); return; } @@ -232,7 +233,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } } catch (Throwable ex) { - logger.error("Terminating STOMP session due to failure to send message", ex); + logger.error("Parsed STOMP message but could not send it to to message channel. " + + "Sending STOMP ERROR to client, sessionId=" + session.getId(), ex); sendErrorMessage(session, ex); } } @@ -248,7 +250,6 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE } protected void sendErrorMessage(WebSocketSession session, Throwable error) { - StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.ERROR); headerAccessor.setMessage(error.getMessage()); byte[] bytes = this.stompEncoder.encode(headerAccessor.getMessageHeaders(), EMPTY_PAYLOAD); @@ -331,7 +332,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE throw ex; } catch (Throwable ex) { - sendErrorMessage(session, ex); + logger.error("Failed to send WebSocket message to client, sessionId=" + session.getId(), ex); + command = StompCommand.ERROR; } finally { if (StompCommand.ERROR.equals(command)) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractClientSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractClientSockJsSession.java new file mode 100644 index 0000000000..16217fa03c --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractClientSockJsSession.java @@ -0,0 +1,338 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketMessage; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; +import org.springframework.web.socket.sockjs.frame.SockJsFrameType; +import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; + +import java.io.IOException; +import java.net.URI; +import java.security.Principal; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Base class for SockJS client implementations of {@link WebSocketSession}. + * Provides processing of incoming SockJS message frames and delegates lifecycle + * events and messages to the (application) {@link WebSocketHandler}. + * Sub-classes implement actual send as well as disconnect logic. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public abstract class AbstractClientSockJsSession implements WebSocketSession { + + protected final Log logger = LogFactory.getLog(getClass()); + + + private final TransportRequest request; + + private final WebSocketHandler webSocketHandler; + + private final SettableListenableFuture connectFuture; + + + private final Map attributes = new ConcurrentHashMap(); + + private volatile State state = State.NEW; + + private volatile CloseStatus closeStatus; + + + protected AbstractClientSockJsSession(TransportRequest request, WebSocketHandler handler, + SettableListenableFuture connectFuture) { + + Assert.notNull(request, "'request' is required"); + Assert.notNull(handler, "'handler' is required"); + Assert.notNull(connectFuture, "'connectFuture' is required"); + this.request = request; + this.webSocketHandler = handler; + this.connectFuture = connectFuture; + } + + + @Override + public String getId() { + return this.request.getSockJsUrlInfo().getSessionId(); + } + + @Override + public URI getUri() { + return this.request.getSockJsUrlInfo().getSockJsUrl(); + } + + @Override + public HttpHeaders getHandshakeHeaders() { + return this.request.getHandshakeHeaders(); + } + + @Override + public Map getAttributes() { + return this.attributes; + } + + @Override + public Principal getPrincipal() { + return this.request.getUser(); + } + + public SockJsMessageCodec getMessageCodec() { + return this.request.getMessageCodec(); + } + + public WebSocketHandler getWebSocketHandler() { + return this.webSocketHandler; + } + + /** + * Return a timeout cleanup task to invoke if the SockJS sessions is not + * fully established within the retransmission timeout period calculated in + * {@code SockJsRequest} based on the duration of the initial SockJS "Info" + * request. + */ + Runnable getTimeoutTask() { + return new Runnable() { + @Override + public void run() { + closeInternal(new CloseStatus(2007, "Transport timed out")); + } + }; + } + + @Override + public boolean isOpen() { + return State.OPEN.equals(this.state); + } + + public boolean isDisconnected() { + return (State.CLOSING.equals(this.state) || State.CLOSED.equals(this.state)); + } + + @Override + public final void sendMessage(WebSocketMessage message) throws IOException { + Assert.state(State.OPEN.equals(this.state), this + " is not open, current state=" + this.state); + Assert.isInstanceOf(TextMessage.class, message, this + " supports text messages only."); + String payload = ((TextMessage) message).getPayload(); + payload = getMessageCodec().encode(new String[] { payload }); + payload = payload.substring(1); // the client-side doesn't need message framing (letter "a") + message = new TextMessage(payload); + if (logger.isTraceEnabled()) { + logger.trace("Sending message " + message + " in " + this); + } + sendInternal((TextMessage) message); + } + + protected abstract void sendInternal(TextMessage textMessage) throws IOException; + + @Override + public final void close() throws IOException { + close(CloseStatus.NORMAL); + } + + @Override + public final void close(CloseStatus status) { + Assert.isTrue(status != null && isUserSetStatus(status), "Invalid close status: " + status); + if (logger.isInfoEnabled()) { + logger.info("Closing session with " + status + " in " + this); + } + closeInternal(status); + } + + private boolean isUserSetStatus(CloseStatus status) { + return (status.getCode() == 1000 || (status.getCode() >= 3000 && status.getCode() <= 4999)); + } + + protected void closeInternal(CloseStatus status) { + if (this.state == null) { + logger.warn("Ignoring close since connect() was never invoked"); + return; + } + if (State.CLOSING.equals(this.state) || State.CLOSED.equals(this.state)) { + logger.debug("Ignoring close (already closing or closed), current state=" + this.state); + return; + } + this.state = State.CLOSING; + this.closeStatus = status; + try { + disconnect(status); + } + catch (Throwable ex) { + if (logger.isErrorEnabled()) { + logger.error("Failed to close " + this, ex); + } + } + } + + protected abstract void disconnect(CloseStatus status) throws IOException; + + public void handleFrame(String payload) { + SockJsFrame frame = new SockJsFrame(payload); + if (SockJsFrameType.OPEN.equals(frame.getType())) { + handleOpenFrame(); + } + else if (SockJsFrameType.MESSAGE.equals(frame.getType())) { + handleMessageFrame(frame); + } + else if (SockJsFrameType.CLOSE.equals(frame.getType())) { + handleCloseFrame(frame); + } + else if (SockJsFrameType.HEARTBEAT.equals(frame.getType())) { + if (logger.isTraceEnabled()) { + logger.trace("Received heartbeat in " + this); + } + } + else { + // should never happen + throw new IllegalStateException("Unknown SockJS frame type " + frame + " in " + this); + } + } + + private void handleOpenFrame() { + if (logger.isInfoEnabled()) { + logger.info("Processing SockJS open frame in " + this); + } + if (State.NEW.equals(state)) { + this.state = State.OPEN; + try { + this.webSocketHandler.afterConnectionEstablished(this); + this.connectFuture.set(this); + } + catch (Throwable ex) { + if (logger.isErrorEnabled()) { + Class type = this.webSocketHandler.getClass(); + logger.error(type + ".afterConnectionEstablished threw exception in " + this, ex); + } + } + } + else { + if (logger.isDebugEnabled()) { + logger.debug("Open frame received in " + getId() + " but we're not" + + "connecting (current state=" + this.state + "). The server might " + + "have been restarted and lost track of the session."); + } + closeInternal(new CloseStatus(1006, "Server lost session")); + } + } + + private void handleMessageFrame(SockJsFrame frame) { + if (!isOpen()) { + if (logger.isWarnEnabled()) { + logger.warn("Ignoring received message due to state=" + this.state + " in " + this); + } + return; + } + String[] messages; + try { + messages = getMessageCodec().decode(frame.getFrameData()); + } + catch (IOException ex) { + if (logger.isErrorEnabled()) { + logger.error("Failed to decode data for SockJS \"message\" frame: " + frame + " in " + this, ex); + } + closeInternal(CloseStatus.BAD_DATA); + return; + } + if (logger.isTraceEnabled()) { + logger.trace("Processing SockJS message frame " + frame.getContent() + " in " + this); + } + for (String message : messages) { + try { + if (isOpen()) { + this.webSocketHandler.handleMessage(this, new TextMessage(message)); + } + } + catch (Throwable ex) { + Class type = this.webSocketHandler.getClass(); + logger.error(type + ".handleMessage threw an exception on " + frame + " in " + this, ex); + } + } + } + + private void handleCloseFrame(SockJsFrame frame) { + CloseStatus closeStatus = CloseStatus.NO_STATUS_CODE; + try { + String[] data = getMessageCodec().decode(frame.getFrameData()); + if (data.length == 2) { + closeStatus = new CloseStatus(Integer.valueOf(data[0]), data[1]); + } + if (logger.isInfoEnabled()) { + logger.info("Processing SockJS close frame with " + closeStatus + " in " + this); + } + } + catch (IOException ex) { + if (logger.isErrorEnabled()) { + logger.error("Failed to decode data for " + frame + " in " + this, ex); + } + } + closeInternal(closeStatus); + } + + public void handleTransportError(Throwable error) { + try { + if (logger.isErrorEnabled()) { + logger.error("Transport error in " + this, error); + } + this.webSocketHandler.handleTransportError(this, error); + } + catch (Exception ex) { + Class type = this.webSocketHandler.getClass(); + if (logger.isErrorEnabled()) { + logger.error(type + ".handleTransportError threw an exception", ex); + } + } + } + + public void afterTransportClosed(CloseStatus closeStatus) { + this.closeStatus = (this.closeStatus != null ? this.closeStatus : closeStatus); + Assert.state(this.closeStatus != null, "CloseStatus not available"); + + if (logger.isInfoEnabled()) { + logger.info("Transport closed with " + this.closeStatus + " in " + this); + } + + this.state = State.CLOSED; + try { + this.webSocketHandler.afterConnectionClosed(this, this.closeStatus); + } + catch (Exception ex) { + if (logger.isErrorEnabled()) { + Class type = this.webSocketHandler.getClass(); + logger.error(type + ".afterConnectionClosed threw an exception", ex); + } + } + } + + @Override + public String toString() { + return getClass().getSimpleName() + "[id='" + getId() + ", url=" + getUri() + "]"; + } + + + private enum State { NEW, OPEN, CLOSING, CLOSED } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java new file mode 100644 index 0000000000..9ef2218528 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java @@ -0,0 +1,163 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; + +import java.net.URI; + +/** + * Abstract base class for XHR transport implementations to extend. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public abstract class AbstractXhrTransport implements XhrTransport { + + protected static final String PRELUDE; + + static { + byte[] bytes = new byte[2048]; + for (int i = 0; i < bytes.length; i++) { + bytes[i] = 'h'; + } + PRELUDE = new String(bytes, SockJsFrame.CHARSET); + } + + protected Log logger = LogFactory.getLog(getClass()); + + private boolean xhrStreamingDisabled; + + private HttpHeaders requestHeaders = new HttpHeaders(); + + private HttpHeaders xhrSendRequestHeaders = new HttpHeaders(); + + + /** + * Whether to attempt to connect with "xhr_streaming" first before trying + * with "xhr" next, see {@link XhrTransport#isXhrStreamingDisabled()}. + * + *

By default this property is set to {@code false} which means both + * "xhr_streaming" and "xhr" will be tried. + */ + public void setXhrStreamingDisabled(boolean disabled) { + this.xhrStreamingDisabled = disabled; + } + + public boolean isXhrStreamingDisabled() { + return this.xhrStreamingDisabled; + } + + /** + * Configure headers to be added to every executed HTTP request. + * @param requestHeaders the headers to add to requests + */ + public void setRequestHeaders(HttpHeaders requestHeaders) { + this.requestHeaders.clear(); + this.xhrSendRequestHeaders.clear(); + if (requestHeaders != null) { + this.requestHeaders.putAll(requestHeaders); + this.xhrSendRequestHeaders.putAll(requestHeaders); + this.xhrSendRequestHeaders.setContentType(MediaType.APPLICATION_JSON); + } + } + + public HttpHeaders getRequestHeaders() { + return this.requestHeaders; + } + + @Override + public String executeInfoRequest(URI infoUrl) { + if (logger.isDebugEnabled()) { + logger.debug("Executing SockJS Info request, url=" + infoUrl); + } + ResponseEntity response = executeInfoRequestInternal(infoUrl); + if (response.getStatusCode() != HttpStatus.OK) { + if (logger.isErrorEnabled()) { + logger.error("SockJS Info request (url=" + infoUrl + ") failed: " + response); + } + throw new HttpServerErrorException(response.getStatusCode()); + } + if (logger.isDebugEnabled()) { + logger.debug("SockJS Info request (url=" + infoUrl + ") response: " + response); + } + return response.getBody(); + } + + protected abstract ResponseEntity executeInfoRequestInternal(URI infoUrl); + + @Override + public void executeSendRequest(URI url, TextMessage message) { + if (logger.isDebugEnabled()) { + logger.debug("Starting XHR send, url=" + url); + } + ResponseEntity response = executeSendRequestInternal(url, this.xhrSendRequestHeaders, message); + if (response.getStatusCode() != HttpStatus.NO_CONTENT) { + if (logger.isErrorEnabled()) { + logger.error("XHR send request (url=" + url + ") failed: " + response); + } + throw new HttpServerErrorException(response.getStatusCode()); + } + if (logger.isDebugEnabled()) { + logger.debug("XHR send request (url=" + url + ") response: " + response); + } + } + + protected abstract ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message); + + @Override + public ListenableFuture connect(TransportRequest request, WebSocketHandler handler) { + SettableListenableFuture connectFuture = new SettableListenableFuture(); + XhrClientSockJsSession session = new XhrClientSockJsSession(request, handler, this, connectFuture); + request.addTimeoutTask(session.getTimeoutTask()); + + URI receiveUrl = request.getTransportUrl(); + if (logger.isDebugEnabled()) { + logger.debug("Opening XHR session, receive url=" + receiveUrl); + } + + HttpHeaders handshakeHeaders = new HttpHeaders(); + handshakeHeaders.putAll(request.getHandshakeHeaders()); + handshakeHeaders.putAll(getRequestHeaders()); + + connectInternal(request, handler, receiveUrl, handshakeHeaders, session, connectFuture); + return connectFuture; + } + + protected abstract void connectInternal(TransportRequest request, WebSocketHandler handler, + URI receiveUrl, HttpHeaders handshakeHeaders, XhrClientSockJsSession session, + SettableListenableFuture connectFuture); + + + @Override + public String toString() { + return getClass().getSimpleName(); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java new file mode 100644 index 0000000000..fd1e630f77 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java @@ -0,0 +1,238 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.SockJsTransportFailureException; +import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.URI; +import java.security.Principal; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * A default implementation of + * {@link org.springframework.web.socket.sockjs.client.TransportRequest + * TransportRequest}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +class DefaultTransportRequest implements TransportRequest { + + private static Log logger = LogFactory.getLog(DefaultTransportRequest.class); + + + private final SockJsUrlInfo sockJsUrlInfo; + + private final HttpHeaders handshakeHeaders; + + private final Transport transport; + + private final TransportType serverTransportType; + + private SockJsMessageCodec codec; + + private Principal user; + + private long timeoutValue; + + private TaskScheduler timeoutScheduler; + + private final List timeoutTasks = new ArrayList(); + + private DefaultTransportRequest fallbackRequest; + + + public DefaultTransportRequest(SockJsUrlInfo sockJsUrlInfo, HttpHeaders handshakeHeaders, + Transport transport, TransportType serverTransportType, SockJsMessageCodec codec) { + + Assert.notNull(sockJsUrlInfo, "'sockJsUrlInfo' is required"); + Assert.notNull(transport, "'transport' is required"); + Assert.notNull(serverTransportType, "'transportType' is required"); + Assert.notNull(codec, "'codec' is required"); + this.sockJsUrlInfo = sockJsUrlInfo; + this.handshakeHeaders = (handshakeHeaders != null ? handshakeHeaders : new HttpHeaders()); + this.transport = transport; + this.serverTransportType = serverTransportType; + this.codec = codec; + } + + + @Override + public SockJsUrlInfo getSockJsUrlInfo() { + return this.sockJsUrlInfo; + } + + @Override + public HttpHeaders getHandshakeHeaders() { + return this.handshakeHeaders; + } + + @Override + public URI getTransportUrl() { + return this.sockJsUrlInfo.getTransportUrl(this.serverTransportType); + } + + public void setUser(Principal user) { + this.user = user; + } + + @Override + public Principal getUser() { + return this.user; + } + + @Override + public SockJsMessageCodec getMessageCodec() { + return this.codec; + } + + public void setTimeoutValue(long timeoutValue) { + this.timeoutValue = timeoutValue; + } + + public void setTimeoutScheduler(TaskScheduler scheduler) { + this.timeoutScheduler = scheduler; + } + + @Override + public void addTimeoutTask(Runnable runnable) { + this.timeoutTasks.add(runnable); + } + + public void setFallbackRequest(DefaultTransportRequest fallbackRequest) { + this.fallbackRequest = fallbackRequest; + } + + + public void connect(WebSocketHandler handler, SettableListenableFuture future) { + if (logger.isDebugEnabled()) { + logger.debug("Starting " + this); + } + ConnectCallback connectCallback = new ConnectCallback(handler, future); + scheduleConnectTimeoutTask(connectCallback); + this.transport.connect(this, handler).addCallback(connectCallback); + } + + + private void scheduleConnectTimeoutTask(ConnectCallback connectHandler) { + if (this.timeoutScheduler != null) { + if (logger.isDebugEnabled()) { + logger.debug("Scheduling connect to time out after " + this.timeoutValue + " milliseconds"); + } + Date timeoutDate = new Date(System.currentTimeMillis() + this.timeoutValue); + this.timeoutScheduler.schedule(connectHandler, timeoutDate); + } + else if (logger.isDebugEnabled()) { + logger.debug("Connect timeout task not scheduled. Is SockJsClient configured with a TaskScheduler?"); + } + } + + + @Override + public String toString() { + return "TransportRequest[url=" + getTransportUrl() + "]"; + } + + + /** + * Updates the given (global) future based success or failure to connect for + * the entire SockJS request regardless of which transport actually managed + * to connect. Also implements {@code Runnable} to handle a scheduled timeout + * callback. + */ + private class ConnectCallback implements ListenableFutureCallback, Runnable { + + private final WebSocketHandler handler; + + private final SettableListenableFuture future; + + private final AtomicBoolean handled = new AtomicBoolean(false); + + + public ConnectCallback(WebSocketHandler handler, SettableListenableFuture future) { + this.handler = handler; + this.future = future; + } + + + @Override + public void onSuccess(WebSocketSession session) { + if (this.handled.compareAndSet(false, true)) { + this.future.set(session); + } + else { + logger.error("Connect success/failure already handled for " + DefaultTransportRequest.this); + } + } + + @Override + public void onFailure(Throwable failure) { + handleFailure(failure, false); + } + + @Override + public void run() { + handleFailure(null, true); + } + + private void handleFailure(Throwable failure, boolean isTimeoutFailure) { + if (this.handled.compareAndSet(false, true)) { + if (isTimeoutFailure) { + String message = "Connect timed out for " + DefaultTransportRequest.this; + logger.error(message); + failure = new SockJsTransportFailureException(message, getSockJsUrlInfo().getSessionId(), null); + } + if (fallbackRequest != null) { + logger.error(DefaultTransportRequest.this + " failed. Falling back on next transport.", failure); + fallbackRequest.connect(this.handler, this.future); + } + else { + logger.error("No more fallback transports after " + DefaultTransportRequest.this, failure); + this.future.setException(failure); + } + if (isTimeoutFailure) { + try { + for (Runnable runnable : timeoutTasks) { + runnable.run(); + } + } + catch (Throwable ex) { + logger.error("Transport failed to run timeout tasks for " + DefaultTransportRequest.this, ex); + } + } + } + else { + logger.error("Connect success/failure events already took place for " + + DefaultTransportRequest.this + ". Ignoring this additional failure event.", failure); + } + } + } +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java new file mode 100644 index 0000000000..ae8ba7bf01 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java @@ -0,0 +1,24 @@ +package org.springframework.web.socket.sockjs.client; + +import java.net.URI; + +/** + * A simple contract for executing the SockJS "Info" request before the SockJS + * session starts. The request is used to check server capabilities such as + * whether it permits use of the WebSocket transport. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public interface InfoReceiver { + + /** + * Perform an HTTP request to the SockJS "Info" URL. + * and return the resulting JSON response content, or raise an exception. + * + * @param infoUrl the URL to obtain SockJS server information from + * @return the body of the response + */ + String executeInfoRequest(URI infoUrl); + +} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java new file mode 100644 index 0000000000..a4f8e23870 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java @@ -0,0 +1,252 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.client.api.ContentResponse; +import org.eclipse.jetty.client.api.Request; +import org.eclipse.jetty.client.api.Response; +import org.eclipse.jetty.client.util.StringContentProvider; +import org.eclipse.jetty.http.HttpFields; +import org.eclipse.jetty.http.HttpMethod; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.SockJsException; +import org.springframework.web.socket.sockjs.SockJsTransportFailureException; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; + +import java.io.ByteArrayOutputStream; +import java.net.URI; +import java.nio.ByteBuffer; +import java.util.Enumeration; + + +/** + * An XHR transport based on Jetty's {@link org.eclipse.jetty.client.HttpClient}. + * + *

When used for testing purposes (e.g. load testing) the {@code HttpClient} + * properties must be set to allow a larger than usual number of connections and + * threads. For example: + * + *

+ * HttpClient httpClient = new HttpClient();
+ * httpClient.setMaxConnectionsPerDestination(1000);
+ * httpClient.setExecutor(new QueuedThreadPool(500));
+ * 
+ * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class JettyXhrTransport extends AbstractXhrTransport implements XhrTransport { + + private final HttpClient httpClient; + + + public JettyXhrTransport(HttpClient httpClient) { + Assert.notNull(httpClient, "'httpClient' is required"); + this.httpClient = httpClient; + } + + + public HttpClient getHttpClient() { + return this.httpClient; + } + + @Override + protected ResponseEntity executeInfoRequestInternal(URI infoUrl) { + return executeRequest(infoUrl, HttpMethod.GET, getRequestHeaders(), null); + } + + @Override + public ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) { + return executeRequest(url, HttpMethod.POST, headers, message.getPayload()); + } + + protected ResponseEntity executeRequest(URI url, HttpMethod method, HttpHeaders headers, String body) { + Request httpRequest = this.httpClient.newRequest(url).method(method); + addHttpHeaders(httpRequest, headers); + if (body != null) { + httpRequest.content(new StringContentProvider(body)); + } + ContentResponse response; + try { + response = httpRequest.send(); + } + catch (Exception ex) { + throw new SockJsTransportFailureException("Failed to execute request to " + url, null, ex); + } + HttpStatus status = HttpStatus.valueOf(response.getStatus()); + HttpHeaders responseHeaders = toHttpHeaders(response.getHeaders()); + return (response.getContent() != null ? + new ResponseEntity(response.getContentAsString(), responseHeaders, status) : + new ResponseEntity(responseHeaders, status)); + } + + private static void addHttpHeaders(Request request, HttpHeaders headers) { + for (String name : headers.keySet()) { + for (String value : headers.get(name)) { + request.header(name, value); + } + } + } + + private static HttpHeaders toHttpHeaders(HttpFields httpFields) { + HttpHeaders responseHeaders = new HttpHeaders(); + Enumeration names = httpFields.getFieldNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + Enumeration values = httpFields.getValues(name); + while (values.hasMoreElements()) { + String value = values.nextElement(); + responseHeaders.add(name, value); + } + } + return responseHeaders; + } + + @Override + protected void connectInternal(TransportRequest request, WebSocketHandler handler, + URI url, HttpHeaders handshakeHeaders, XhrClientSockJsSession session, + SettableListenableFuture connectFuture) { + + SockJsResponseListener listener = new SockJsResponseListener(url, getRequestHeaders(), session, connectFuture); + executeReceiveRequest(url, handshakeHeaders, listener); + } + + private void executeReceiveRequest(URI url, HttpHeaders headers, SockJsResponseListener listener) { + if (logger.isDebugEnabled()) { + logger.debug("Starting XHR receive request, url=" + url); + } + Request httpRequest = this.httpClient.newRequest(url).method(HttpMethod.POST); + addHttpHeaders(httpRequest, headers); + httpRequest.send(listener); + } + + + /** + * Splits the body of an HTTP response into SockJS frames and delegates those + * to an {@link XhrClientSockJsSession}. + */ + private class SockJsResponseListener extends Response.Listener.Adapter { + + private final URI transportUrl; + + private final HttpHeaders receiveHeaders; + + private final XhrClientSockJsSession sockJsSession; + + private final SettableListenableFuture connectFuture; + + private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + + + public SockJsResponseListener(URI url, HttpHeaders headers, XhrClientSockJsSession sockJsSession, + SettableListenableFuture connectFuture) { + + this.transportUrl = url; + this.receiveHeaders = headers; + this.connectFuture = connectFuture; + this.sockJsSession = sockJsSession; + } + + + @Override + public void onBegin(Response response) { + if (response.getStatus() != 200) { + HttpStatus status = HttpStatus.valueOf(response.getStatus()); + response.abort(new HttpServerErrorException(status, "Unexpected XHR receive status")); + } + } + + @Override + public void onHeaders(Response response) { + if (logger.isDebugEnabled()) { + // Convert to HttpHeaders to avoid "\n" + logger.debug("XHR receive headers: " + toHttpHeaders(response.getHeaders())); + } + } + + @Override + public void onContent(Response response, ByteBuffer buffer) { + while (true) { + if (this.sockJsSession.isDisconnected()) { + if (logger.isDebugEnabled()) { + logger.debug("SockJS sockJsSession closed. Closing ClientHttpResponse."); + } + response.abort(new SockJsException("Session closed.", this.sockJsSession.getId(), null)); + return; + } + if (buffer.remaining() == 0) { + break; + } + int b = buffer.get(); + if (b == '\n') { + handleFrame(); + } + else { + this.outputStream.write(b); + } + } + } + + private void handleFrame() { + byte[] bytes = this.outputStream.toByteArray(); + this.outputStream.reset(); + String content = new String(bytes, SockJsFrame.CHARSET); + if (logger.isTraceEnabled()) { + logger.trace("XHR content received: " + content); + } + if (!PRELUDE.equals(content)) { + this.sockJsSession.handleFrame(new String(bytes, SockJsFrame.CHARSET)); + } + } + + @Override + public void onSuccess(Response response) { + if (this.outputStream.size() > 0) { + handleFrame(); + } + if (logger.isDebugEnabled()) { + logger.debug("XHR receive request completed."); + } + executeReceiveRequest(this.transportUrl, this.receiveHeaders, this); + } + + @Override + public void onFailure(Response response, Throwable failure) { + if (connectFuture.setException(failure)) { + return; + } + if (this.sockJsSession.isDisconnected()) { + this.sockJsSession.afterTransportClosed(null); + } + else { + this.sockJsSession.handleTransportError(failure); + this.sockJsSession.afterTransportClosed(new CloseStatus(1006, failure.getMessage())); + } + } + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java new file mode 100644 index 0000000000..1f25be4009 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java @@ -0,0 +1,265 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.springframework.core.task.SimpleAsyncTaskExecutor; +import org.springframework.core.task.TaskExecutor; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.util.Assert; +import org.springframework.util.StreamUtils; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.client.RequestCallback; +import org.springframework.web.client.ResponseExtractor; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; + +/** + * An {@code XhrTransport} implementation that uses a + * {@link org.springframework.web.client.RestTemplate RestTemplate}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class RestTemplateXhrTransport extends AbstractXhrTransport implements XhrTransport { + + private final RestOperations restTemplate; + + private TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(); + + + public RestTemplateXhrTransport() { + this(new RestTemplate()); + } + + public RestTemplateXhrTransport(RestOperations restTemplate) { + Assert.notNull(restTemplate, "'restTemplate' is required"); + this.restTemplate = restTemplate; + } + + + /** + * Return the configured {@code RestTemplate}. + */ + public RestOperations getRestTemplate() { + return this.restTemplate; + } + + /** + * Configure the {@code TaskExecutor} to use to execute XHR receive requests. + * + *

By default {@link org.springframework.core.task.SimpleAsyncTaskExecutor + * SimpleAsyncTaskExecutor} is configured which creates a new thread every + * time the transports connects. + * + * @param taskExecutor the task executor, cannot be {@code null} + */ + public void setTaskExecutor(TaskExecutor taskExecutor) { + Assert.notNull(this.taskExecutor); + this.taskExecutor = taskExecutor; + } + + /** + * Return the configured {@code TaskExecutor}. + */ + public TaskExecutor getTaskExecutor() { + return this.taskExecutor; + } + + + @Override + public ResponseEntity executeInfoRequestInternal(URI infoUrl) { + RequestCallback requestCallback = new XhrRequestCallback(getRequestHeaders()); + return this.restTemplate.execute(infoUrl, HttpMethod.GET, requestCallback, textExtractor); + } + + @Override + public ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) { + RequestCallback requestCallback = new XhrRequestCallback(headers, message.getPayload()); + return this.restTemplate.execute(url, HttpMethod.POST, requestCallback, textExtractor); + } + + @Override + protected void connectInternal(final TransportRequest request, final WebSocketHandler handler, + final URI receiveUrl, final HttpHeaders handshakeHeaders, final XhrClientSockJsSession session, + final SettableListenableFuture connectFuture) { + + getTaskExecutor().execute(new Runnable() { + @Override + public void run() { + XhrRequestCallback requestCallback = new XhrRequestCallback(handshakeHeaders); + XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(getRequestHeaders()); + XhrReceiveExtractor responseExtractor = new XhrReceiveExtractor(session); + while (true) { + if (session.isDisconnected()) { + session.afterTransportClosed(null); + break; + } + try { + if (logger.isDebugEnabled()) { + logger.debug("Starting XHR receive request, url=" + receiveUrl); + } + getRestTemplate().execute(receiveUrl, HttpMethod.POST, requestCallback, responseExtractor); + requestCallback = requestCallbackAfterHandshake; + } + catch (Throwable ex) { + if (!connectFuture.isDone()) { + connectFuture.setException(ex); + } + else { + session.handleTransportError(ex); + session.afterTransportClosed(new CloseStatus(1006, ex.getMessage())); + } + break; + } + } + } + }); + } + + + /** + * A RequestCallback to add the headers and (optionally) String content. + */ + private static class XhrRequestCallback implements RequestCallback { + + private final HttpHeaders headers; + + private final String body; + + + public XhrRequestCallback(HttpHeaders headers) { + this(headers, null); + } + + public XhrRequestCallback(HttpHeaders headers, String body) { + this.headers = headers; + this.body = body; + } + + + @Override + public void doWithRequest(ClientHttpRequest request) throws IOException { + if (this.headers != null) { + request.getHeaders().putAll(this.headers); + } + if (this.body != null) { + StreamUtils.copy(this.body, SockJsFrame.CHARSET, request.getBody()); + } + } + } + + /** + * A simple ResponseExtractor that reads the body into a String. + */ + private final static ResponseExtractor> textExtractor = + new ResponseExtractor>() { + + @Override + public ResponseEntity extractData(ClientHttpResponse response) throws IOException { + if (response.getBody() == null) { + return new ResponseEntity(response.getHeaders(), response.getStatusCode()); + } + else { + String body = StreamUtils.copyToString(response.getBody(), SockJsFrame.CHARSET); + return new ResponseEntity(body, response.getHeaders(), response.getStatusCode()); + } + } + }; + + /** + * Splits the body of an HTTP response into SockJS frames and delegates those + * to an {@link XhrClientSockJsSession}. + */ + private class XhrReceiveExtractor implements ResponseExtractor { + + private final XhrClientSockJsSession sockJsSession; + + + public XhrReceiveExtractor(XhrClientSockJsSession sockJsSession) { + this.sockJsSession = sockJsSession; + } + + + @Override + public Object extractData(ClientHttpResponse response) throws IOException { + if (!HttpStatus.OK.equals(response.getStatusCode())) { + throw new HttpServerErrorException(response.getStatusCode()); + } + if (logger.isDebugEnabled()) { + logger.debug("XHR receive headers: " + response.getHeaders()); + } + InputStream is = response.getBody(); + ByteArrayOutputStream os = new ByteArrayOutputStream(); + while (true) { + if (this.sockJsSession.isDisconnected()) { + if (logger.isDebugEnabled()) { + logger.debug("SockJS sockJsSession closed. Closing ClientHttpResponse."); + } + response.close(); + break; + } + int b = is.read(); + if (b == -1) { + if (os.size() > 0) { + handleFrame(os); + } + if (logger.isDebugEnabled()) { + logger.debug("XHR receive completed"); + } + break; + } + if (b == '\n') { + handleFrame(os); + } + else { + os.write(b); + } + } + return null; + } + + private void handleFrame(ByteArrayOutputStream os) { + byte[] bytes = os.toByteArray(); + os.reset(); + String content = new String(bytes, SockJsFrame.CHARSET); + if (logger.isTraceEnabled()) { + logger.trace("XHR receive content: " + content); + } + if (!PRELUDE.equals(content)) { + this.sockJsSession.handleFrame(new String(bytes, SockJsFrame.CHARSET)); + } + } + } + +} + diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java new file mode 100644 index 0000000000..5d8bef8cfe --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java @@ -0,0 +1,259 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.client.AbstractWebSocketClient; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.URI; +import java.security.Principal; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * A SockJS implementation of + * {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient} + * with HTTP-based fallback alternative simulating a WebSocket interaction. + * + * @author Rossen Stoyanchev + * @since 4.1 + * + * @see http://sockjs.org + * @see org.springframework.web.socket.sockjs.client.Transport + */ +public class SockJsClient extends AbstractWebSocketClient { + + private static final boolean jackson2Present = ClassUtils.isPresent( + "com.fasterxml.jackson.databind.ObjectMapper", SockJsClient.class.getClassLoader()); + + private static final Log logger = LogFactory.getLog(SockJsClient.class); + + + private final List transports; + + private InfoReceiver infoReceiver; + + private SockJsMessageCodec messageCodec; + + private TaskScheduler taskScheduler; + + private final Map infoCache = new ConcurrentHashMap(); + + + /** + * Create a {@code SockJsClient} with the given transports. + * @param transports the transports to use + */ + public SockJsClient(List transports) { + Assert.notEmpty(transports, "No transports provided"); + this.transports = new ArrayList(transports); + this.infoReceiver = initInfoReceiver(transports); + if (jackson2Present) { + this.messageCodec = new Jackson2SockJsMessageCodec(); + } + } + + private static InfoReceiver initInfoReceiver(List transports) { + for (Transport transport : transports) { + if (transport instanceof InfoReceiver) { + return ((InfoReceiver) transport); + } + } + return new RestTemplateXhrTransport(); + } + + + /** + * Configure the {@code InfoReceiver} to use to perform the SockJS "Info" + * request before the SockJS session starts. + * + *

By default this is initialized either by looking through the configured + * transports to find the first {@code XhrTransport} or by creating an + * instance of {@code RestTemplateXhrTransport}. + * + * @param infoReceiver the transport to use for the SockJS "Info" request + */ + public void setInfoReceiver(InfoReceiver infoReceiver) { + this.infoReceiver = infoReceiver; + } + + public InfoReceiver getInfoReceiver() { + return this.infoReceiver; + } + + /** + * Set the SockJsMessageCodec to use. + * + *

By default {@link org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec + * Jackson2SockJsMessageCodec} is used if Jackson is on the classpath. + * + * @param messageCodec the message messageCodec to use + */ + public void setMessageCodec(SockJsMessageCodec messageCodec) { + Assert.notNull(messageCodec, "'messageCodec' is required"); + this.messageCodec = messageCodec; + } + + public SockJsMessageCodec getMessageCodec() { + return this.messageCodec; + } + + /** + * Configure a {@code TaskScheduler} for scheduling a connect timeout task + * where the timeout value is calculated based on the duration of the initial + * SockJS info request. Having a connect timeout task is optional but can + * improve the speed with which the client falls back to alternative + * transport options. + * + *

By default no task scheduler is configured in which case it may take + * longer before a fallback transport can be used. + * + * @param taskScheduler the scheduler to use + */ + public void setTaskScheduler(TaskScheduler taskScheduler) { + this.taskScheduler = taskScheduler; + } + + public void clearServerInfoCache() { + this.infoCache.clear(); + } + + @Override + protected void assertUri(URI uri) { + Assert.notNull(uri, "uri must not be null"); + String scheme = uri.getScheme(); + Assert.isTrue(scheme != null && ("ws".equals(scheme) || "wss".equals(scheme) + || "http".equals(scheme) || "https".equals(scheme)), "Invalid scheme: " + scheme); + } + + @Override + protected ListenableFuture doHandshakeInternal(WebSocketHandler handler, + HttpHeaders handshakeHeaders, URI url, List protocols, + List extensions, Map attributes) { + + SettableListenableFuture connectFuture = new SettableListenableFuture(); + try { + SockJsUrlInfo sockJsUrlInfo = new SockJsUrlInfo(url); + ServerInfo serverInfo = getServerInfo(sockJsUrlInfo); + createFallbackChain(sockJsUrlInfo, handshakeHeaders, serverInfo).connect(handler, connectFuture); + } + catch (Throwable exception) { + if (logger.isErrorEnabled()) { + logger.error("Initial SockJS \"Info\" request to server failed, url=" + url, exception); + } + connectFuture.setException(exception); + } + return connectFuture; + } + + private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo) { + URI infoUrl = sockJsUrlInfo.getInfoUrl(); + ServerInfo info = this.infoCache.get(infoUrl); + if (info == null) { + long start = System.currentTimeMillis(); + String response = this.infoReceiver.executeInfoRequest(infoUrl); + long infoRequestTime = System.currentTimeMillis() - start; + info = new ServerInfo(response, infoRequestTime); + this.infoCache.put(infoUrl, info); + } + return info; + } + + private DefaultTransportRequest createFallbackChain(SockJsUrlInfo urlInfo, HttpHeaders headers, ServerInfo serverInfo) { + List requests = new ArrayList(this.transports.size()); + for (Transport transport : this.transports) { + if (transport instanceof XhrTransport) { + XhrTransport xhrTransport = (XhrTransport) transport; + if (!xhrTransport.isXhrStreamingDisabled()) { + addRequest(requests, urlInfo, headers, serverInfo, transport, TransportType.XHR_STREAMING); + } + addRequest(requests, urlInfo, headers, serverInfo, transport, TransportType.XHR); + } + else if (serverInfo.isWebSocketEnabled()) { + addRequest(requests, urlInfo, headers, serverInfo, transport, TransportType.WEBSOCKET); + } + } + Assert.notEmpty(requests, + "0 transports for request to " + urlInfo + " . Configured transports: " + + this.transports + ". SockJS server webSocketEnabled=" + serverInfo.isWebSocketEnabled()); + for (int i = 0; i < requests.size() - 1; i++) { + requests.get(i).setFallbackRequest(requests.get(i + 1)); + } + return requests.get(0); + } + + private void addRequest(List requests, SockJsUrlInfo info, HttpHeaders headers, + ServerInfo serverInfo, Transport transport, TransportType type) { + + DefaultTransportRequest request = new DefaultTransportRequest(info, headers, transport, type, getMessageCodec()); + request.setUser(getUser()); + if (this.taskScheduler != null) { + request.setTimeoutValue(serverInfo.getRetransmissionTimeout()); + request.setTimeoutScheduler(this.taskScheduler); + } + requests.add(request); + } + + /** + * Return the user to associate with the SockJS session and make available via + * {@link org.springframework.web.socket.WebSocketSession#getPrincipal() + * WebSocketSession#getPrincipal()}. + *

By default this method returns {@code null}. + * @return the user to associate with the session, possibly {@code null} + */ + protected Principal getUser() { + return null; + } + + + private static class ServerInfo { + + private final boolean webSocketEnabled; + + private final long responseTime; + + + private ServerInfo(String response, long responseTime) { + this.responseTime = responseTime; + this.webSocketEnabled = !response.matches(".*[\"']websocket[\"']\\s*:\\s*false.*"); + } + + public boolean isWebSocketEnabled() { + return this.webSocketEnabled; + } + + public long getRetransmissionTimeout() { + return (this.responseTime > 100 ? 4 * this.responseTime : this.responseTime + 300); + } + } + +} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfo.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfo.java new file mode 100644 index 0000000000..6530f0a62b --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfo.java @@ -0,0 +1,115 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.springframework.util.AlternativeJdkIdGenerator; +import org.springframework.util.IdGenerator; +import org.springframework.web.socket.sockjs.transport.TransportType; +import org.springframework.web.util.UriComponentsBuilder; + +import java.net.URI; +import java.util.UUID; + +/** + * Given the base URL to a SockJS server endpoint, also provides methods to + * generate and obtain session and a server id used for construct a transport URL. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class SockJsUrlInfo { + + private static final IdGenerator idGenerator = new AlternativeJdkIdGenerator(); + + + private final URI sockJsUrl; + + private String serverId; + + private String sessionId; + + private UUID uuid; + + + public SockJsUrlInfo(URI sockJsUrl) { + this.sockJsUrl = sockJsUrl; + } + + + public URI getSockJsUrl() { + return this.sockJsUrl; + } + + public String getServerId() { + if (this.serverId == null) { + this.serverId = String.valueOf(Math.abs(getUuid().getMostSignificantBits()) % 1000); + } + return this.serverId; + } + + public String getSessionId() { + if (this.sessionId == null) { + this.sessionId = getUuid().toString().replace("-",""); + } + return this.sessionId; + } + + protected UUID getUuid() { + if (this.uuid == null) { + this.uuid = idGenerator.generateId(); + } + return this.uuid; + } + + public URI getInfoUrl() { + return UriComponentsBuilder.fromUri(this.sockJsUrl) + .scheme(getScheme(TransportType.XHR)) + .pathSegment("info") + .build(true).toUri(); + } + + public URI getTransportUrl(TransportType transportType) { + return UriComponentsBuilder.fromUri(this.sockJsUrl) + .scheme(getScheme(transportType)) + .pathSegment(getServerId()) + .pathSegment(getSessionId()) + .pathSegment(transportType.toString()) + .build(true).toUri(); + } + + private String getScheme(TransportType transportType) { + String scheme = this.sockJsUrl.getScheme(); + if (TransportType.WEBSOCKET.equals(transportType)) { + if (!scheme.startsWith("ws")) { + scheme = ("https".equals(scheme) ? "wss" : "ws"); + } + } + else { + if (!scheme.startsWith("http")) { + scheme = ("wss".equals(scheme) ? "https" : "http"); + } + } + return scheme; + } + + + @Override + public String toString() { + return "SockJsUrlInfo[url=" + this.sockJsUrl + "]"; + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/Transport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/Transport.java new file mode 100644 index 0000000000..41d554a5ef --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/Transport.java @@ -0,0 +1,40 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.socket.sockjs.client; + + +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; + +/** + * A client-side implementation for a SockJS transport. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public interface Transport { + + /** + * Connect the transport. + * + * @param request the transport request. + * @param webSocketHandler the application handler to delegate lifecycle events to. + * @return a future to indicate success or failure to connect. + */ + ListenableFuture connect(TransportRequest request, WebSocketHandler webSocketHandler); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java new file mode 100644 index 0000000000..5e92b28296 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java @@ -0,0 +1,53 @@ +package org.springframework.web.socket.sockjs.client; + +import org.springframework.http.HttpHeaders; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.URI; +import java.security.Principal; + +/** + * Represents a request to connect to a SockJS service using a specific + * Transport. A single SockJS request however may require falling back + * and therefore multiple TransportRequest instances. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public interface TransportRequest { + + /** + * Return information about the SockJS URL including server and session id.. + */ + SockJsUrlInfo getSockJsUrlInfo(); + + /** + * Return the headers to send with the connect request. + */ + HttpHeaders getHandshakeHeaders(); + + /** + * Return the transport URL for the given transport. + * For an {@link XhrTransport} this is the URL for receiving messages. + */ + URI getTransportUrl(); + + /** + * Return the user associated with the request, if any. + */ + Principal getUser(); + + /** + * Return the message codec to use for encoding SockJS messages. + */ + SockJsMessageCodec getMessageCodec(); + + /** + * Register a timeout cleanup task to invoke if the SockJS session is not + * fully established within the calculated retransmission timeout period. + */ + void addTimeoutTask(Runnable runnable); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketClientSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketClientSockJsSession.java new file mode 100644 index 0000000000..543a6a2c73 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketClientSockJsSession.java @@ -0,0 +1,136 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + + +import org.springframework.util.Assert; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.adapter.NativeWebSocketSession; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.List; + +/** + * An extension of {@link AbstractClientSockJsSession} wrapping and delegating + * to an actual WebSocket session. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class WebSocketClientSockJsSession extends AbstractClientSockJsSession implements NativeWebSocketSession { + + private WebSocketSession webSocketSession; + + + public WebSocketClientSockJsSession(TransportRequest request, WebSocketHandler handler, + SettableListenableFuture connectFuture) { + + super(request, handler, connectFuture); + } + + + @Override + public Object getNativeSession() { + return this.webSocketSession; + } + + @SuppressWarnings("unchecked") + @Override + public T getNativeSession(Class requiredType) { + if (requiredType != null) { + if (requiredType.isInstance(this.webSocketSession)) { + return (T) this.webSocketSession; + } + } + return null; + } + + @Override + public InetSocketAddress getLocalAddress() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getLocalAddress(); + } + + @Override + public InetSocketAddress getRemoteAddress() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getRemoteAddress(); + } + + @Override + public String getAcceptedProtocol() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getAcceptedProtocol(); + } + + @Override + public void setTextMessageSizeLimit(int messageSizeLimit) { + checkDelegateSessionInitialized(); + this.webSocketSession.setTextMessageSizeLimit(messageSizeLimit); + } + + @Override + public int getTextMessageSizeLimit() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getTextMessageSizeLimit(); + } + + @Override + public void setBinaryMessageSizeLimit(int messageSizeLimit) { + checkDelegateSessionInitialized(); + this.webSocketSession.setBinaryMessageSizeLimit(messageSizeLimit); + } + + @Override + public int getBinaryMessageSizeLimit() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getBinaryMessageSizeLimit(); + } + + @Override + public List getExtensions() { + checkDelegateSessionInitialized(); + return this.webSocketSession.getExtensions(); + } + + private void checkDelegateSessionInitialized() { + Assert.state(this.webSocketSession != null, "WebSocketSession not yet initialized"); + } + + public void initializeDelegateSession(WebSocketSession session) { + this.webSocketSession = session; + } + + @Override + protected void sendInternal(TextMessage textMessage) throws IOException { + this.webSocketSession.sendMessage(textMessage); + } + + @Override + protected void disconnect(CloseStatus status) throws IOException { + if (this.webSocketSession != null) { + this.webSocketSession.close(status); + } + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketTransport.java new file mode 100644 index 0000000000..ef36a96840 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/WebSocketTransport.java @@ -0,0 +1,129 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.util.Assert; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketHttpHeaders; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.client.WebSocketClient; +import org.springframework.web.socket.handler.TextWebSocketHandler; + +import java.net.URI; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * A SockJS {@link Transport} that uses a + * {@link org.springframework.web.socket.client.WebSocketClient WebSocketClient}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class WebSocketTransport implements Transport { + + private static Log logger = LogFactory.getLog(WebSocketTransport.class); + + private final WebSocketClient webSocketClient; + + + public WebSocketTransport(WebSocketClient webSocketClient) { + Assert.notNull(webSocketClient, "'webSocketClient' is required"); + this.webSocketClient = webSocketClient; + } + + + /** + * Return the configured {@code WebSocketClient}. + */ + public WebSocketClient getWebSocketClient() { + return this.webSocketClient; + } + + @Override + public ListenableFuture connect(TransportRequest request, WebSocketHandler handler) { + final SettableListenableFuture future = new SettableListenableFuture(); + WebSocketClientSockJsSession session = new WebSocketClientSockJsSession(request, handler, future); + handler = new ClientSockJsWebSocketHandler(session); + request.addTimeoutTask(session.getTimeoutTask()); + + URI url = request.getTransportUrl(); + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(request.getHandshakeHeaders()); + if (logger.isDebugEnabled()) { + logger.debug("Opening WebSocket connection, url=" + url); + } + this.webSocketClient.doHandshake(handler, headers, url).addCallback( + new ListenableFutureCallback() { + @Override + public void onSuccess(WebSocketSession webSocketSession) { + // WebSocket session ready, SockJS Session not yet + } + @Override + public void onFailure(Throwable t) { + future.setException(t); + } + }); + return future; + } + + @Override + public String toString() { + return "WebSocketTransport[client=" + this.webSocketClient + "]"; + } + + + private static class ClientSockJsWebSocketHandler extends TextWebSocketHandler { + + private final WebSocketClientSockJsSession sockJsSession; + + private final AtomicInteger connectCount = new AtomicInteger(0); + + + private ClientSockJsWebSocketHandler(WebSocketClientSockJsSession session) { + Assert.notNull(session); + this.sockJsSession = session; + } + + @Override + public void afterConnectionEstablished(WebSocketSession webSocketSession) throws Exception { + Assert.isTrue(this.connectCount.compareAndSet(0, 1)); + this.sockJsSession.initializeDelegateSession(webSocketSession); + } + + @Override + public void handleTextMessage(WebSocketSession webSocketSession, TextMessage message) throws Exception { + this.sockJsSession.handleFrame(message.getPayload()); + } + + @Override + public void handleTransportError(WebSocketSession webSocketSession, Throwable ex) throws Exception { + this.sockJsSession.handleTransportError(ex); + } + + @Override + public void afterConnectionClosed(WebSocketSession webSocketSession, CloseStatus status) throws Exception { + this.sockJsSession.afterTransportClosed(status); + } + } + +} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java new file mode 100644 index 0000000000..92de51cea7 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java @@ -0,0 +1,111 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.springframework.util.Assert; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.util.List; + + +/** + * An extension of {@link AbstractClientSockJsSession} for use with HTTP + * transports simulating a WebSocket session. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class XhrClientSockJsSession extends AbstractClientSockJsSession { + + private final URI sendUrl; + + private final XhrTransport transport; + + private int textMessageSizeLimit = -1; + + private int binaryMessageSizeLimit = -1; + + + public XhrClientSockJsSession(TransportRequest request, WebSocketHandler handler, + XhrTransport transport, SettableListenableFuture connectFuture) { + + super(request, handler, connectFuture); + Assert.notNull(transport, "'restTemplate' is required"); + this.sendUrl = request.getSockJsUrlInfo().getTransportUrl(TransportType.XHR_SEND); + this.transport = transport; + } + + + @Override + public InetSocketAddress getLocalAddress() { + return null; + } + + @Override + public InetSocketAddress getRemoteAddress() { + return new InetSocketAddress(getUri().getHost(), getUri().getPort()); + } + + @Override + public String getAcceptedProtocol() { + return null; + } + + @Override + public void setTextMessageSizeLimit(int messageSizeLimit) { + this.textMessageSizeLimit = messageSizeLimit; + } + + @Override + public int getTextMessageSizeLimit() { + return this.textMessageSizeLimit; + } + + @Override + public void setBinaryMessageSizeLimit(int messageSizeLimit) { + this.binaryMessageSizeLimit = -1; + } + + @Override + public int getBinaryMessageSizeLimit() { + return this.binaryMessageSizeLimit; + } + + @Override + public List getExtensions() { + return null; + } + + @Override + protected void sendInternal(TextMessage message) { + this.transport.executeSendRequest(this.sendUrl, message); + } + + @Override + protected void disconnect(CloseStatus status) { + // Nothing to do, XHR transports check if session is disconnected + } + +} \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java new file mode 100644 index 0000000000..726b202464 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java @@ -0,0 +1,40 @@ +package org.springframework.web.socket.sockjs.client; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.ResponseEntity; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; + +import java.net.URI; + +/** + * A SockJS {@link Transport} that uses HTTP requests to simulate a WebSocket + * interaction. The {@code connect} method of the base {@code Transport} interface + * is used to receive messages from the server while the + * {@link #executeSendRequest(java.net.URI, org.springframework.web.socket.TextMessage) + * executeSendRequest(URI, TextMessage)} method here is used to send messages. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public interface XhrTransport extends Transport, InfoReceiver { + + /** + * An {@code XhrTransport} supports both the "xhr_streaming" and "xhr" SockJS + * server transports. From a client perspective there is no implementation + * difference. + * + *

By default an {@code XhrTransport} will be used with "xhr_streaming" + * first and then with "xhr", if the streaming fails to connect. In some + * cases it may be useful to suppress streaming so that only "xhr" is used. + */ + boolean isXhrStreamingDisabled(); + + /** + * Execute a request to send the message to the server. + * @param transportUrl the URL for sending messages. + * @param message the message to send + */ + void executeSendRequest(URI transportUrl, TextMessage message); + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/package-info.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/package-info.java new file mode 100644 index 0000000000..6ec13fd08e --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * SockJS client implementation of + * {@link org.springframework.web.socket.client.WebSocketClient}. + */ +package org.springframework.web.socket.sockjs.client; + diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/DefaultSockJsFrameFormat.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/DefaultSockJsFrameFormat.java index 56220f5f46..fbb92f0944 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/DefaultSockJsFrameFormat.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/DefaultSockJsFrameFormat.java @@ -19,6 +19,10 @@ package org.springframework.web.socket.sockjs.frame; import org.springframework.util.Assert; /** + * A default implementation of + * {@link org.springframework.web.socket.sockjs.frame.SockJsFrameFormat} that relies + * on {@link java.lang.String#format(String, Object...)}.. + * * @author Rossen Stoyanchev * @since 4.0 */ @@ -33,14 +37,9 @@ public class DefaultSockJsFrameFormat implements SockJsFrameFormat { } - /** - * @param frame the SockJs frame. - * @return new SockJsFrame instance with the formatted content - */ @Override - public SockJsFrame format(SockJsFrame frame) { - String content = String.format(this.format, preProcessContent(frame.getContent())); - return new SockJsFrame(content); + public String format(SockJsFrame frame) { + return String.format(this.format, preProcessContent(frame.getContent())); } protected String preProcessContent(String content) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java index fbd5ca13bb..99888fcdd5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrame.java @@ -18,33 +18,69 @@ package org.springframework.web.socket.sockjs.frame; import java.nio.charset.Charset; -import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** - * Represents a SockJS frame and provides factory methods for creating SockJS frames. + * Represents a SockJS frame. Provides factory methods to create SockJS frames. * * @author Rossen Stoyanchev * @since 4.0 */ public class SockJsFrame { - private static final Charset UTF8_CHARSET = Charset.forName("UTF-8"); + public static final Charset CHARSET = Charset.forName("UTF-8"); - private static final SockJsFrame openFrame = new SockJsFrame("o"); + private static final SockJsFrame OPEN_FRAME = new SockJsFrame("o"); - private static final SockJsFrame heartbeatFrame = new SockJsFrame("h"); + private static final SockJsFrame HEARTBEAT_FRAME = new SockJsFrame("h"); - private static final SockJsFrame closeGoAwayFrame = closeFrame(3000, "Go away!"); + private static final SockJsFrame CLOSE_GO_AWAY_FRAME = closeFrame(3000, "Go away!"); - private static final SockJsFrame closeAnotherConnectionOpenFrame = closeFrame(2010, "Another connection still open"); + private static final SockJsFrame CLOSE_ANOTHER_CONNECTION_OPEN_FRAME = closeFrame(2010, "Another connection still open"); + private final SockJsFrameType type; + + private final String content; + + + /** + * Create a new instance frame with the given frame content. + * @param content the content, must be a non-empty and represent a valid SockJS frame + */ + public SockJsFrame(String content) { + StringUtils.hasText(content); + if ("o".equals(content)) { + this.type = SockJsFrameType.OPEN; + this.content = content; + } + else if ("h".equals(content)) { + this.type = SockJsFrameType.HEARTBEAT; + this.content = content; + } + else if (content.charAt(0) == 'a') { + this.type = SockJsFrameType.MESSAGE; + this.content = (content.length() > 1 ? content : "a[]"); + } + else if (content.charAt(0) == 'm') { + this.type = SockJsFrameType.MESSAGE; + this.content = (content.length() > 1 ? content : "null"); + } + else if (content.charAt(0) == 'c') { + this.type = SockJsFrameType.CLOSE; + this.content = (content.length() > 1 ? content : "c[]"); + } + else { + throw new IllegalArgumentException("Unexpected SockJS frame type in content=\"" + content + "\""); + } + } + public static SockJsFrame openFrame() { - return openFrame; + return OPEN_FRAME; } public static SockJsFrame heartbeatFrame() { - return heartbeatFrame; + return HEARTBEAT_FRAME; } public static SockJsFrame messageFrame(SockJsMessageCodec codec, String... messages) { @@ -53,11 +89,11 @@ public class SockJsFrame { } public static SockJsFrame closeFrameGoAway() { - return closeGoAwayFrame; + return CLOSE_GO_AWAY_FRAME; } public static SockJsFrame closeFrameAnotherConnectionOpen() { - return closeAnotherConnectionOpenFrame; + return CLOSE_ANOTHER_CONNECTION_OPEN_FRAME; } public static SockJsFrame closeFrame(int code, String reason) { @@ -65,23 +101,42 @@ public class SockJsFrame { } - private final String content; - - - public SockJsFrame(String content) { - Assert.notNull("Content must not be null"); - this.content = content; + /** + * Return the SockJS frame type. + */ + public SockJsFrameType getType() { + return this.type; } - + /** + * Return the SockJS frame content, never {@code null}. + */ public String getContent() { return this.content; } + /** + * Return the SockJS frame content as a byte array. + */ public byte[] getContentBytes() { - return this.content.getBytes(UTF8_CHARSET); + return this.content.getBytes(CHARSET); } + /** + * Return data contained in a SockJS "message" and "close" frames. Otherwise + * for SockJS "open" and "close" frames, which do not contain data, return + * {@code null}. + */ + public String getFrameData() { + if (SockJsFrameType.OPEN == getType() || SockJsFrameType.HEARTBEAT == getType()) { + return null; + } + else { + return getContent().substring(1); + } + } + + @Override public boolean equals(Object other) { if (this == other) { @@ -90,7 +145,7 @@ public class SockJsFrame { if (!(other instanceof SockJsFrame)) { return false; } - return this.content.equals(((SockJsFrame) other).content); + return (this.type.equals(((SockJsFrame) other).type) && this.content.equals(((SockJsFrame) other).content)); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameFormat.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameFormat.java index 858d2476ed..26376e468c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameFormat.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameFormat.java @@ -17,11 +17,22 @@ package org.springframework.web.socket.sockjs.frame; /** + * Applies a transport-specific format to the content of a SockJS frame resulting + * in a content that can be written out. Primarily for use in HTTP server-side + * transports that push data. + * + *

Formatting may vary from simply appending a new line character for XHR + * polling and streaming transports, to a jsonp-style callback function, + * surrounding script tags, and more. + * + *

For the various SockJS frame formats in use, see implementations of + * {@link org.springframework.web.socket.sockjs.transport.handler.AbstractHttpSendingTransportHandler#getFrameFormat(org.springframework.http.server.ServerHttpRequest) AbstractHttpSendingTransportHandler.getFrameFormat} + * * @author Rossen Stoyanchev * @since 4.0 */ public interface SockJsFrameFormat { - SockJsFrame format(SockJsFrame frame); + String format(SockJsFrame frame); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameType.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameType.java new file mode 100644 index 0000000000..eb6fc66460 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/frame/SockJsFrameType.java @@ -0,0 +1,29 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.frame; + +/** + * SockJS frame types. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public enum SockJsFrameType { + + OPEN, HEARTBEAT, MESSAGE, CLOSE + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java index d9a06db592..dc8d020b74 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java @@ -238,7 +238,7 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem } else { response.setStatusCode(HttpStatus.NOT_FOUND); - logger.warn("Session not found"); + logger.warn("Session not found, sessionId=" + sessionId); return; } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java index 0bba697e10..0acd733767 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/AbstractHttpSendingTransportHandler.java @@ -82,12 +82,12 @@ public abstract class AbstractHttpSendingTransportHandler extends AbstractTransp } else { logger.debug("another " + getTransportType() + " connection still open: " + sockJsSession); - SockJsFrame frame = getFrameFormat(request).format(SockJsFrame.closeFrameAnotherConnectionOpen()); + String formattedFrame = getFrameFormat(request).format(SockJsFrame.closeFrameAnotherConnectionOpen()); try { - response.getBody().write(frame.getContentBytes()); + response.getBody().write(formattedFrame.getBytes(SockJsFrame.CHARSET)); } catch (IOException ex) { - throw new SockJsException("Failed to send " + frame, sockJsSession.getId(), ex); + throw new SockJsException("Failed to send " + formattedFrame, sockJsSession.getId(), ex); } } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java index f433ea5f27..c2f8de4718 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractHttpSockJsSession.java @@ -30,7 +30,6 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.server.ServerHttpAsyncRequestControl; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; -import org.springframework.util.Assert; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; @@ -41,7 +40,7 @@ import org.springframework.web.socket.sockjs.frame.SockJsFrameFormat; import org.springframework.web.socket.sockjs.transport.SockJsServiceConfig; /** - * An abstract base class for use with HTTP transport based SockJS sessions. + * An abstract base class for use with HTTP transport SockJS sessions. * * @author Rossen Stoyanchev * @since 4.0 @@ -64,9 +63,12 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { private volatile ServerHttpResponse response; + private volatile SockJsFrameFormat frameFormat; + + private volatile ServerHttpAsyncRequestControl asyncRequestControl; - private volatile SockJsFrameFormat frameFormat; + private final Object responseLock = new Object(); private volatile boolean requestInitialized; @@ -124,18 +126,10 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { return this.acceptedProtocol; } - /** - * Return response for the current request, or {@code null} if between requests. - */ - protected ServerHttpResponse getResponse() { - return this.response; - } - /** * Return the SockJS buffer for messages stored transparently between polling * requests. If the polling request takes longer than 5 seconds, the session - * will be closed. - * + * is closed. * @see org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService */ protected Queue getMessageCache() { @@ -173,123 +167,104 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { return Collections.emptyList(); } + /** + * Whether this HTTP transport streams message frames vs closing the response + * after each frame written (long polling). + */ + protected abstract boolean isStreaming(); + /** - * Handle the first HTTP request, i.e. the one that starts a SockJS session. - * Write a prelude to the response (if needed), send the SockJS "open" frame - * to indicate to the client the session is opened, and invoke the - * delegate WebSocketHandler to provide it with the newly opened session. - *

- * The "xhr" and "jsonp" (polling-based) transports completes the initial request - * as soon as the open frame is sent. Following that the client should start a - * successive polling request within the same SockJS session. - *

- * The "xhr_streaming", "eventsource", and "htmlfile" transports are streaming - * based and will leave the initial request open in order to stream one or - * more messages. However, even streaming based transports eventually recycle - * the long running request, after a certain number of bytes have been streamed - * (128K by default), and allow the client to start a successive request within - * the same SockJS session. + * Handle the first request for receiving messages on a SockJS HTTP transport + * based session. + * + *

Long polling-based transports (e.g. "xhr", "jsonp") complete the request + * after writing the open frame. Streaming-based transports ("xhr_streaming", + * "eventsource", and "htmlfile") leave the response open longer for further + * streaming of message frames but will also close it eventually after some + * amount of data has been sent. * * @param request the current request * @param response the current response * @param frameFormat the transport-specific SocksJS frame format to use - * - * @see #handleSuccessiveRequest(org.springframework.http.server.ServerHttpRequest, org.springframework.http.server.ServerHttpResponse, org.springframework.web.socket.sockjs.frame.SockJsFrameFormat) */ public void handleInitialRequest(ServerHttpRequest request, ServerHttpResponse response, SockJsFrameFormat frameFormat) throws SockJsException { - initRequest(request, response, frameFormat); - this.uri = request.getURI(); this.handshakeHeaders = request.getHeaders(); this.principal = request.getPrincipal(); this.localAddress = request.getLocalAddress(); this.remoteAddress = request.getRemoteAddress(); + this.response = response; + this.frameFormat = frameFormat; + this.asyncRequestControl = request.getAsyncRequestControl(response); + try { + // Let "our" handler know before sending the open frame to the remote handler + delegateConnectionEstablished(); writePrelude(request, response); writeFrame(SockJsFrame.openFrame()); + if (isStreaming() && !isClosed()) { + startAsyncRequest(); + } } catch (Throwable ex) { tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); - throw new SockJsTransportFailureException("Failed to send \"open\" frame", getId(), ex); - } - - try { - this.requestInitialized = true; - delegateConnectionEstablished(); - } - catch (Throwable ex) { - throw new SockJsException("Unhandled exception from WebSocketHandler", getId(), ex); + throw new SockJsTransportFailureException("Failed to open session", getId(), ex); } } - private void initRequest(ServerHttpRequest request, ServerHttpResponse response, - SockJsFrameFormat frameFormat) { - - Assert.notNull(request, "Request must not be null"); - Assert.notNull(response, "Response must not be null"); - Assert.notNull(frameFormat, "SockJsFrameFormat must not be null"); - - this.response = response; - this.frameFormat = frameFormat; - this.asyncRequestControl = request.getAsyncRequestControl(response); + protected void writePrelude(ServerHttpRequest request, ServerHttpResponse response) throws IOException { } - protected void writePrelude(ServerHttpRequest request, ServerHttpResponse response) throws IOException { + private void startAsyncRequest() { + this.asyncRequestControl.start(-1); + if (this.messageCache.size() > 0) { + flushCache(); + } + else { + scheduleHeartbeat(); + } + this.requestInitialized = true; } /** - * Handle all HTTP requests part of the same SockJS session except for the very - * first, initial request. Write a prelude (if needed) and keep the request - * open and ready to send a message from the server to the client. - *

- * The "xhr" and "jsonp" (polling-based) transports completes the request when - * the next message is sent, which could be an array of messages cached during - * the time between successive requests, or it could be a heartbeat message - * sent if no other messages were sent (by default within 25 seconds). - *

- * The "xhr_streaming", "eventsource", and "htmlfile" transports are streaming - * based and will leave the request open longer in order to stream messages over - * a period of time. However, even streaming based transports eventually recycle - * the long running request, after a certain number of bytes have been streamed - * (128K by default), and allow the client to start a successive request within - * the same SockJS session. + * Handle all requests, except the first one, to receive messages on a SockJS + * HTTP transport based session. + * + *

Long polling-based transports (e.g. "xhr", "jsonp") complete the request + * after writing any buffered message frames (or the next one). Streaming-based + * transports ("xhr_streaming", "eventsource", and "htmlfile") leave the + * response open longer for further streaming of message frames but will also + * close it eventually after some amount of data has been sent. * * @param request the current request * @param response the current response * @param frameFormat the transport-specific SocksJS frame format to use - * - * @see #handleInitialRequest(org.springframework.http.server.ServerHttpRequest, org.springframework.http.server.ServerHttpResponse, org.springframework.web.socket.sockjs.frame.SockJsFrameFormat) */ public void handleSuccessiveRequest(ServerHttpRequest request, ServerHttpResponse response, SockJsFrameFormat frameFormat) throws SockJsException { - initRequest(request, response, frameFormat); - try { - writePrelude(request, response); - } - catch (Throwable ex) { - tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); - throw new SockJsTransportFailureException("Failed to send \"open\" frame", getId(), ex); + synchronized (this.responseLock) { + try { + if (isClosed()) { + response.getBody().write(SockJsFrame.closeFrameGoAway().getContentBytes()); + } + this.response = response; + this.frameFormat = frameFormat; + this.asyncRequestControl = request.getAsyncRequestControl(response); + writePrelude(request, response); + startAsyncRequest(); + } + catch (Throwable ex) { + tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); + throw new SockJsTransportFailureException("Failed to handle SockJS receive request", getId(), ex); + } } - startAsyncRequest(); } - protected void startAsyncRequest() throws SockJsException { - try { - this.asyncRequestControl.start(-1); - this.requestInitialized = true; - scheduleHeartbeat(); - tryFlushCache(); - } - catch (Throwable ex) { - tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); - throw new SockJsTransportFailureException("Failed to flush messages", getId(), ex); - } - } @Override protected final void sendMessageInternal(String message) throws SockJsTransportFailureException { @@ -297,27 +272,35 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { tryFlushCache(); } - private void tryFlushCache() throws SockJsTransportFailureException { - if (this.messageCache.isEmpty()) { - logger.trace("Nothing to flush"); - return; - } - if (logger.isTraceEnabled()) { - logger.trace(this.messageCache.size() + " message(s) to flush"); - } - if (isActive() && this.requestInitialized) { - logger.trace("Flushing messages"); - flushCache(); - } - else { + private boolean tryFlushCache() throws SockJsTransportFailureException { + synchronized (this.responseLock) { + if (this.messageCache.isEmpty()) { + logger.trace("Nothing to flush in session=" + this.getId()); + return false; + } if (logger.isTraceEnabled()) { - logger.trace("Not ready to flush"); + logger.trace(this.messageCache.size() + " message(s) to flush in session " + this.getId()); + } + if (isActive() && this.requestInitialized) { + if (logger.isTraceEnabled()) { + logger.trace("Session is active, ready to flush."); + } + cancelHeartbeat(); + flushCache(); + return true; + } + else { + if (logger.isTraceEnabled()) { + logger.trace("Session is not active, not ready to flush."); + } + return false; } } } /** - * Only called if the connection is currently active + * Called when the connection is active and ready to write to the response. + * Sub-classes should implement but never call this method directly. */ protected abstract void flushCache() throws SockJsTransportFailureException; @@ -327,35 +310,43 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession { } protected void resetRequest() { + synchronized (this.responseLock) { - this.requestInitialized = false; - updateLastActiveTime(); - - if (isActive()) { ServerHttpAsyncRequestControl control = this.asyncRequestControl; - if (control.isStarted()) { - try { - logger.debug("Completing asynchronous request"); - control.complete(); - } - catch (Throwable ex) { - logger.error("Failed to complete request: " + ex.getMessage()); + this.asyncRequestControl = null; + this.requestInitialized = false; + this.response = null; + + updateLastActiveTime(); + + if (control != null && !control.isCompleted()) { + if (control.isStarted()) { + try { + logger.debug("Completing asynchronous request"); + control.complete(); + } + catch (Throwable ex) { + logger.error("Failed to complete request: " + ex.getMessage()); + } } } } - - this.response = null; - this.asyncRequestControl = null; } @Override protected void writeFrameInternal(SockJsFrame frame) throws IOException { if (isActive()) { - frame = this.frameFormat.format(frame); + String formattedFrame = this.frameFormat.format(frame); if (logger.isTraceEnabled()) { - logger.trace("Writing " + frame); + logger.trace("Writing to HTTP response: " + formattedFrame); + } + this.response.getBody().write(formattedFrame.getBytes(SockJsFrame.CHARSET)); + if (isStreaming()) { + this.response.flush(); + } + else { + resetRequest(); } - getResponse().getBody().write(frame.getContentBytes()); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java index 215a69f6c3..0a0ec5770a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/AbstractSockJsSession.java @@ -40,6 +40,7 @@ import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.sockjs.SockJsMessageDeliveryException; import org.springframework.web.socket.sockjs.SockJsTransportFailureException; import org.springframework.web.socket.sockjs.frame.SockJsFrame; +import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; import org.springframework.web.socket.sockjs.transport.SockJsServiceConfig; import org.springframework.web.socket.sockjs.transport.SockJsSession; @@ -142,6 +143,10 @@ public abstract class AbstractSockJsSession implements SockJsSession { return this.id; } + protected SockJsMessageCodec getMessageCodec() { + return this.config.getMessageCodec(); + } + public SockJsServiceConfig getSockJsServiceConfig() { return this.config; } @@ -420,7 +425,9 @@ public abstract class AbstractSockJsSession implements SockJsSession { @Override public String toString() { - return "SockJS session id=" + this.id; + long currentTime = System.currentTimeMillis(); + return "SockJsSession[id=" + this.id + ", state=" + this.state + ", sinceCreated=" + + (currentTime - this.timeCreated) + ", sinceLastActive=" + (currentTime - this.timeLastActive) + "]"; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java index 52e27829a6..fa2f721722 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/PollingSockJsSession.java @@ -16,6 +16,8 @@ package org.springframework.web.socket.sockjs.transport.session; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.Queue; @@ -33,6 +35,7 @@ import org.springframework.web.socket.sockjs.transport.SockJsServiceConfig; */ public class PollingSockJsSession extends AbstractHttpSockJsSession { + public PollingSockJsSession(String sessionId, SockJsServiceConfig config, WebSocketHandler wsHandler, Map attributes) { @@ -41,22 +44,20 @@ public class PollingSockJsSession extends AbstractHttpSockJsSession { @Override - protected void flushCache() throws SockJsTransportFailureException { - cancelHeartbeat(); - Queue messageCache = getMessageCache(); - String[] messages = messageCache.toArray(new String[messageCache.size()]); - messageCache.clear(); + protected boolean isStreaming() { + return false; + } + @Override + protected void flushCache() throws SockJsTransportFailureException { + String[] messages = new String[getMessageCache().size()]; + for (int i = 0; i < messages.length; i++) { + messages[i] = getMessageCache().poll(); + } SockJsMessageCodec messageCodec = getSockJsServiceConfig().getMessageCodec(); SockJsFrame frame = SockJsFrame.messageFrame(messageCodec, messages); writeFrame(frame); } - @Override - protected void writeFrame(SockJsFrame frame) throws SockJsTransportFailureException { - super.writeFrame(frame); - resetRequest(); - } - } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java index 70facc8776..1e2d301dbe 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/StreamingSockJsSession.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package org.springframework.web.socket.sockjs.transport.session; -import java.io.IOException; import java.util.Map; import org.springframework.http.server.ServerHttpRequest; @@ -48,21 +47,12 @@ public class StreamingSockJsSession extends AbstractHttpSockJsSession { @Override - public void handleInitialRequest(ServerHttpRequest request, ServerHttpResponse response, - SockJsFrameFormat frameFormat) throws SockJsException { - - super.handleInitialRequest(request, response, frameFormat); - - // the WebSocketHandler delegate may have closed the session - if (!isClosed()) { - super.startAsyncRequest(); - } + protected boolean isStreaming() { + return true; } @Override protected void flushCache() throws SockJsTransportFailureException { - cancelHeartbeat(); - do { String message = getMessageCache().poll(); SockJsMessageCodec messageCodec = getSockJsServiceConfig().getMessageCodec(); @@ -79,26 +69,12 @@ public class StreamingSockJsSession extends AbstractHttpSockJsSession { logger.trace("Streamed bytes limit reached. Recycling current request"); } resetRequest(); + this.byteCount = 0; break; } } while (!getMessageCache().isEmpty()); - scheduleHeartbeat(); } - @Override - protected void resetRequest() { - super.resetRequest(); - this.byteCount = 0; - } - - @Override - protected void writeFrameInternal(SockJsFrame frame) throws IOException { - if (isActive()) { - super.writeFrameInternal(frame); - getResponse().flush(); - } - } - } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java index 51e2df3f22..d250d9e34c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/session/WebSocketServerSockJsSession.java @@ -22,6 +22,10 @@ import java.net.URI; import java.security.Principal; import java.util.List; import java.util.Map; +import java.util.Queue; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import org.springframework.http.HttpHeaders; import org.springframework.util.Assert; @@ -34,7 +38,6 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.NativeWebSocketSession; import org.springframework.web.socket.sockjs.SockJsTransportFailureException; import org.springframework.web.socket.sockjs.frame.SockJsFrame; -import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; import org.springframework.web.socket.sockjs.transport.SockJsServiceConfig; /** @@ -47,6 +50,12 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession implemen private WebSocketSession webSocketSession; + private volatile boolean openFrameSent; + + private final Queue initSessionCache = new LinkedBlockingDeque(); + + private final Lock initSessionLock = new ReentrantLock(); + public WebSocketServerSockJsSession(String id, SockJsServiceConfig config, WebSocketHandler handler, Map attributes) { @@ -143,15 +152,23 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession implemen public void initializeDelegateSession(WebSocketSession session) { - this.webSocketSession = session; - try { - TextMessage message = new TextMessage(SockJsFrame.openFrame().getContent()); - this.webSocketSession.sendMessage(message); - scheduleHeartbeat(); - delegateConnectionEstablished(); - } - catch (Exception ex) { - tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); + synchronized (this.initSessionLock) { + this.webSocketSession = session; + try { + // Let "our" handler know before sending the open frame to the remote handler + delegateConnectionEstablished(); + this.webSocketSession.sendMessage(new TextMessage(SockJsFrame.openFrame().getContent())); + + // Flush any messages cached in the mean time + while (!this.initSessionCache.isEmpty()) { + writeFrame(SockJsFrame.messageFrame(getMessageCodec(), this.initSessionCache.poll())); + } + scheduleHeartbeat(); + this.openFrameSent = true; + } + catch (Exception ex) { + tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR); + } } } @@ -180,10 +197,20 @@ public class WebSocketServerSockJsSession extends AbstractSockJsSession implemen @Override public void sendMessageInternal(String message) throws SockJsTransportFailureException { + + // Open frame not sent yet? + // If in the session initialization thread, then cache, otherwise wait. + + if (!this.openFrameSent) { + synchronized (this.initSessionLock) { + if (!this.openFrameSent) { + this.initSessionCache.add(message); + return; + } + } + } cancelHeartbeat(); - SockJsMessageCodec messageCodec = getSockJsServiceConfig().getMessageCodec(); - SockJsFrame frame = SockJsFrame.messageFrame(messageCodec, message); - writeFrame(frame); + writeFrame(SockJsFrame.messageFrame(getMessageCodec(), message)); scheduleHeartbeat(); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java index 5d9af0aa31..f4b2fd3f91 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java @@ -17,12 +17,17 @@ package org.springframework.web.socket; import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.FilterHolder; import org.eclipse.jetty.servlet.ServletContextHandler; import org.eclipse.jetty.servlet.ServletHolder; import org.springframework.util.SocketUtils; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.DispatcherServlet; +import javax.servlet.DispatcherType; +import javax.servlet.Filter; +import java.util.EnumSet; + /** * Jetty based {@link WebSocketTestServer}. * @@ -46,13 +51,20 @@ public class JettyWebSocketTestServer implements WebSocketTestServer { } @Override - public void deployConfig(WebApplicationContext cxt) { + public void deployConfig(WebApplicationContext cxt, Filter... filters) { ServletContextHandler contextHandler = new ServletContextHandler(); ServletHolder servletHolder = new ServletHolder(new DispatcherServlet(cxt)); contextHandler.addServlet(servletHolder, "/"); + for (Filter filter : filters) { + contextHandler.addFilter(new FilterHolder(filter), "/*", getDispatcherTypes()); + } this.jettyServer.setHandler(contextHandler); } + private EnumSet getDispatcherTypes() { + return EnumSet.of(DispatcherType.REQUEST, DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.ASYNC); + } + @Override public void undeployConfig() { // Stopping jetty will undeploy the servlet @@ -66,6 +78,7 @@ public class JettyWebSocketTestServer implements WebSocketTestServer { @Override public void stop() throws Exception { if (this.jettyServer.isRunning()) { + this.jettyServer.setStopTimeout(0); this.jettyServer.stop(); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/TomcatWebSocketTestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/TomcatWebSocketTestServer.java index 2d1e855a84..caff98f687 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/TomcatWebSocketTestServer.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/TomcatWebSocketTestServer.java @@ -18,17 +18,23 @@ package org.springframework.web.socket; import java.io.File; import java.io.IOException; +import java.util.EnumSet; import org.apache.catalina.Context; import org.apache.catalina.connector.Connector; import org.apache.catalina.startup.Tomcat; import org.apache.coyote.http11.Http11NioProtocol; import org.apache.tomcat.util.descriptor.web.ApplicationListener; +import org.apache.tomcat.util.descriptor.web.FilterDef; +import org.apache.tomcat.util.descriptor.web.FilterMap; import org.apache.tomcat.websocket.server.WsContextListener; import org.springframework.util.SocketUtils; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.DispatcherServlet; +import javax.servlet.DispatcherType; +import javax.servlet.Filter; + /** * Tomcat based {@link WebSocketTestServer}. * @@ -82,11 +88,27 @@ public class TomcatWebSocketTestServer implements WebSocketTestServer { } @Override - public void deployConfig(WebApplicationContext wac) { + public void deployConfig(WebApplicationContext wac, Filter... filters) { this.context = this.tomcatServer.addContext("", System.getProperty("java.io.tmpdir")); this.context.addApplicationListener(WS_APPLICATION_LISTENER); - Tomcat.addServlet(context, "dispatcherServlet", new DispatcherServlet(wac)); + Tomcat.addServlet(this.context, "dispatcherServlet", new DispatcherServlet(wac)).setAsyncSupported(true); this.context.addServletMapping("/", "dispatcherServlet"); + for (Filter filter : filters) { + FilterDef filterDef = new FilterDef(); + filterDef.setFilterName(filter.getClass().getName()); + filterDef.setFilter(filter); + filterDef.setAsyncSupported("true"); + this.context.addFilterDef(filterDef); + FilterMap filterMap = new FilterMap(); + filterMap.setFilterName(filter.getClass().getName()); + filterMap.addURLPattern("/*"); + filterMap.setDispatcher("REQUEST,FORWARD,INCLUDE,ASYNC"); + this.context.addFilterMap(filterMap); + } + } + + private EnumSet getDispatcherTypes() { + return EnumSet.of(DispatcherType.REQUEST, DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.ASYNC); } @Override diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/UndertowTestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/UndertowTestServer.java index 0980cbc76a..7d0cdd77b1 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/UndertowTestServer.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/UndertowTestServer.java @@ -16,12 +16,15 @@ package org.springframework.web.socket; +import javax.servlet.DispatcherType; +import javax.servlet.Filter; import javax.servlet.Servlet; import javax.servlet.ServletException; import io.undertow.Undertow; import io.undertow.servlet.api.DeploymentInfo; import io.undertow.servlet.api.DeploymentManager; +import io.undertow.servlet.api.FilterInfo; import io.undertow.servlet.api.InstanceFactory; import io.undertow.servlet.api.InstanceHandle; import io.undertow.websockets.jsr.WebSocketDeploymentInfo; @@ -56,7 +59,7 @@ public class UndertowTestServer implements WebSocketTestServer { } @Override - public void deployConfig(WebApplicationContext cxt) { + public void deployConfig(WebApplicationContext cxt, Filter... filters) { DispatcherServletInstanceFactory servletFactory = new DispatcherServletInstanceFactory(cxt); DeploymentInfo servletBuilder = deployment() @@ -66,6 +69,13 @@ public class UndertowTestServer implements WebSocketTestServer { .addServlet(servlet("DispatcherServlet", DispatcherServlet.class, servletFactory).addMapping("/")) .addServletContextAttribute(WebSocketDeploymentInfo.ATTRIBUTE_NAME, new WebSocketDeploymentInfo()); + for (final Filter filter : filters) { + String filterName = filter.getClass().getName(); + servletBuilder.addFilter(new FilterInfo(filterName, filter.getClass(), new FilterInstanceFactory(filter))); + for (DispatcherType type : DispatcherType.values()) { + servletBuilder.addFilterUrlMapping(filterName, "/*", type); + } + } this.manager = defaultContainer().addDeployment(servletBuilder); this.manager.deploy(); @@ -117,4 +127,25 @@ public class UndertowTestServer implements WebSocketTestServer { } } + private static class FilterInstanceFactory implements InstanceFactory { + + private final Filter filter; + + private FilterInstanceFactory(Filter filter) { + this.filter = filter; + } + + @Override + public InstanceHandle createInstance() throws InstantiationException { + return new InstanceHandle() { + @Override + public Filter getInstance() { + return filter; + } + @Override + public void release() {} + }; + } + } + } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java index 8673abb394..04fe9651ec 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java @@ -58,24 +58,22 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTest @Override protected Class[] getAnnotatedConfigClasses() { - return new Class[] {TestWebSocketConfigurer.class}; + return new Class[] { TestConfig.class }; } @Test public void subProtocolNegotiation() throws Exception { WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); headers.setSecWebSocketProtocol("foo"); - - WebSocketSession session = this.webSocketClient.doHandshake( - new AbstractWebSocketHandler() {}, headers, new URI(getWsBaseUrl() + "/ws")).get(); - + URI url = new URI(getWsBaseUrl() + "/ws"); + WebSocketSession session = this.webSocketClient.doHandshake(new TextWebSocketHandler(), headers, url).get(); assertEquals("foo", session.getAcceptedProtocol()); } @Configuration @EnableWebSocket - static class TestWebSocketConfigurer implements WebSocketConfigurer { + static class TestConfig implements WebSocketConfigurer { @Autowired private DefaultHandshakeHandler handshakeHandler; // can't rely on classpath for server detection @@ -83,17 +81,13 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTest @Override public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { this.handshakeHandler.setSupportedProtocols("foo", "bar", "baz"); - registry.addHandler(serverHandler(), "/ws").setHandshakeHandler(this.handshakeHandler); + registry.addHandler(handler(), "/ws").setHandshakeHandler(this.handshakeHandler); } @Bean - public TestServerWebSocketHandler serverHandler() { - return new TestServerWebSocketHandler(); + public TextWebSocketHandler handler() { + return new TextWebSocketHandler(); } } - - private static class TestServerWebSocketHandler extends TextWebSocketHandler { - } - } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketTestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketTestServer.java index 0de0efa81c..aa841e61fe 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketTestServer.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketTestServer.java @@ -18,6 +18,8 @@ package org.springframework.web.socket; import org.springframework.web.context.WebApplicationContext; +import javax.servlet.Filter; + /** * Contract for a test server to use for WebSocket integration tests. * @@ -27,7 +29,7 @@ public interface WebSocketTestServer { int getPort(); - void deployConfig(WebApplicationContext cxt); + void deployConfig(WebApplicationContext cxt, Filter... filters); void undeployConfig(); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index 67556f61d4..dfc73941cd 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -140,6 +140,7 @@ public class MessageBrokerBeanDefinitionParserTests { ThreadPoolTaskScheduler scheduler = (ThreadPoolTaskScheduler) defaultSockJsService.getTaskScheduler(); assertEquals(Runtime.getRuntime().availableProcessors(), scheduler.getScheduledThreadPoolExecutor().getCorePoolSize()); + assertTrue(scheduler.getScheduledThreadPoolExecutor().getRemoveOnCancelPolicy()); UserSessionRegistry userSessionRegistry = this.appContext.getBean(UserSessionRegistry.class); assertNotNull(userSessionRegistry); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationTests.java index 3a2e50bb10..4559712b28 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationTests.java @@ -57,7 +57,7 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes @Override protected Class[] getAnnotatedConfigClasses() { - return new Class[] {TestWebSocketConfigurer.class}; + return new Class[] { TestConfig.class }; } @Test @@ -65,7 +65,7 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes WebSocketSession session = this.webSocketClient.doHandshake( new AbstractWebSocketHandler() {}, getWsBaseUrl() + "/ws").get(); - TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class); + TestHandler serverHandler = this.wac.getBean(TestHandler.class); assertTrue(serverHandler.connectLatch.await(2, TimeUnit.SECONDS)); session.close(); @@ -76,7 +76,7 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes WebSocketSession session = this.webSocketClient.doHandshake( new AbstractWebSocketHandler() {}, getWsBaseUrl() + "/sockjs/websocket").get(); - TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class); + TestHandler serverHandler = this.wac.getBean(TestHandler.class); assertTrue(serverHandler.connectLatch.await(2, TimeUnit.SECONDS)); session.close(); @@ -85,7 +85,7 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes @Configuration @EnableWebSocket - static class TestWebSocketConfigurer implements WebSocketConfigurer { + static class TestConfig implements WebSocketConfigurer { @Autowired private HandshakeHandler handshakeHandler; // can't rely on classpath for server detection @@ -99,12 +99,12 @@ public class WebSocketConfigurationTests extends AbstractWebSocketIntegrationTes } @Bean - public TestWebSocketHandler serverHandler() { - return new TestWebSocketHandler(); + public TestHandler serverHandler() { + return new TestHandler(); } } - private static class TestWebSocketHandler extends AbstractWebSocketHandler { + private static class TestHandler extends AbstractWebSocketHandler { private CountDownLatch connectLatch = new CountDownLatch(1); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java index 16d64a0577..0f16354d22 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ScheduledThreadPoolExecutor; import org.junit.Before; import org.junit.Test; @@ -128,8 +129,9 @@ public class WebSocketMessageBrokerConfigurationSupportTests { ThreadPoolTaskScheduler taskScheduler = this.config.getBean("messageBrokerSockJsTaskScheduler", ThreadPoolTaskScheduler.class); - assertEquals(Runtime.getRuntime().availableProcessors(), - taskScheduler.getScheduledThreadPoolExecutor().getCorePoolSize()); + ScheduledThreadPoolExecutor executor = taskScheduler.getScheduledThreadPoolExecutor(); + assertEquals(Runtime.getRuntime().availableProcessors(), executor.getCorePoolSize()); + assertTrue(executor.getRemoveOnCancelPolicy()); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java new file mode 100644 index 0000000000..016b7ef924 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java @@ -0,0 +1,394 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.WebSocketTestServer; +import org.springframework.web.socket.config.annotation.EnableWebSocket; +import org.springframework.web.socket.config.annotation.WebSocketConfigurer; +import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; +import org.springframework.web.socket.handler.TextWebSocketHandler; +import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.RequestUpgradeStrategy; +import org.springframework.web.socket.server.support.DefaultHandshakeHandler; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.hamcrest.Matchers.*; + +/** + * Integration tests using the + * {@link org.springframework.web.socket.sockjs.client.SockJsClient}. + * against actual SockJS server endpoints. + * + * @author Rossen Stoyanchev + */ +public abstract class AbstractSockJsIntegrationTests { + + protected Log logger = LogFactory.getLog(getClass()); + + private WebSocketTestServer server; + + private AnnotationConfigWebApplicationContext wac; + + private ErrorFilter errorFilter; + + private String baseUrl; + + + @Before + public void setup() throws Exception { + this.errorFilter = new ErrorFilter(); + this.wac = new AnnotationConfigWebApplicationContext(); + this.wac.register(TestConfig.class, upgradeStrategyConfigClass()); + this.wac.refresh(); + this.server = createWebSocketTestServer(); + this.server.deployConfig(this.wac, this.errorFilter); + this.server.start(); + this.baseUrl = "http://localhost:" + this.server.getPort(); + } + + @After + public void teardown() throws Exception { + try { + this.server.undeployConfig(); + } + catch (Throwable t) { + logger.error("Failed to undeploy application config", t); + } + try { + this.server.stop(); + } + catch (Throwable t) { + logger.error("Failed to stop server", t); + } + } + + protected abstract WebSocketTestServer createWebSocketTestServer(); + + protected abstract Class upgradeStrategyConfigClass(); + + protected abstract Transport getWebSocketTransport(); + + protected abstract AbstractXhrTransport getXhrTransport(); + + protected SockJsClient createSockJsClient(Transport... transports) { + return new SockJsClient(Arrays.asList(transports)); + } + + @Test + public void echoWebSocket() throws Exception { + testEcho(100, getWebSocketTransport()); + } + + @Test + public void echoXhrStreaming() throws Exception { + testEcho(100, getXhrTransport()); + } + + @Test + public void echoXhr() throws Exception { + AbstractXhrTransport xhrTransport = getXhrTransport(); + xhrTransport.setXhrStreamingDisabled(true); + testEcho(100, xhrTransport); + } + + @Test + public void closeAfterOneMessageWebSocket() throws Exception { + testCloseAfterOneMessage(getWebSocketTransport()); + } + + @Test + public void closeAfterOneMessageXhrStreaming() throws Exception { + testCloseAfterOneMessage(getXhrTransport()); + } + + @Test + public void closeAfterOneMessageXhr() throws Exception { + AbstractXhrTransport xhrTransport = getXhrTransport(); + xhrTransport.setXhrStreamingDisabled(true); + testCloseAfterOneMessage(xhrTransport); + } + + @Test + public void infoRequestFailure() throws Exception { + TestClientHandler handler = new TestClientHandler(); + this.errorFilter.responseStatusMap.put("/info", 500); + CountDownLatch latch = new CountDownLatch(1); + createSockJsClient(getWebSocketTransport()).doHandshake(handler, this.baseUrl + "/echo").addCallback( + new ListenableFutureCallback() { + @Override + public void onSuccess(WebSocketSession result) { + + } + @Override + public void onFailure(Throwable t) { + latch.countDown(); + } + } + ); + assertTrue(latch.await(5000, TimeUnit.MILLISECONDS)); + } + + @Test + public void fallbackAfterTransportFailure() throws Exception { + this.errorFilter.responseStatusMap.put("/websocket", 200); + this.errorFilter.responseStatusMap.put("/xhr_streaming", 500); + TestClientHandler handler = new TestClientHandler(); + Transport[] transports = { getWebSocketTransport(), getXhrTransport() }; + WebSocketSession session = createSockJsClient(transports).doHandshake(handler, this.baseUrl + "/echo").get(); + assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, session.getClass()); + TextMessage message = new TextMessage("message1"); + session.sendMessage(message); + handler.awaitMessage(message, 5000); + } + + @Test(timeout = 5000) + public void fallbackAfterConnectTimeout() throws Exception { + TestClientHandler clientHandler = new TestClientHandler(); + this.errorFilter.sleepDelayMap.put("/xhr_streaming", 10000L); + this.errorFilter.responseStatusMap.put("/xhr_streaming", 503); + SockJsClient sockJsClient = createSockJsClient(getXhrTransport()); + sockJsClient.setTaskScheduler(this.wac.getBean(ThreadPoolTaskScheduler.class)); + WebSocketSession clientSession = sockJsClient.doHandshake(clientHandler, this.baseUrl + "/echo").get(); + assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, clientSession.getClass()); + TextMessage message = new TextMessage("message1"); + clientSession.sendMessage(message); + clientHandler.awaitMessage(message, 5000); + clientSession.close(); + } + + + private void testEcho(int messageCount, Transport transport) throws Exception { + List messages = new ArrayList<>(); + for (int i = 0; i < messageCount; i++) { + messages.add(new TextMessage("m" + i)); + } + TestClientHandler handler = new TestClientHandler(); + WebSocketSession session = createSockJsClient(transport).doHandshake(handler, this.baseUrl + "/echo").get(); + for (TextMessage message : messages) { + session.sendMessage(message); + } + handler.awaitMessageCount(messageCount, 5000); + for (TextMessage message : messages) { + assertTrue("Message not received: " + message, handler.receivedMessages.remove(message)); + } + assertEquals("Remaining messages: " + handler.receivedMessages, 0, handler.receivedMessages.size()); + session.close(); + } + + private void testCloseAfterOneMessage(Transport transport) throws Exception { + TestClientHandler clientHandler = new TestClientHandler(); + createSockJsClient(transport).doHandshake(clientHandler, this.baseUrl + "/test").get(); + TestServerHandler serverHandler = this.wac.getBean(TestServerHandler.class); + + assertNotNull("afterConnectionEstablished should have been called", clientHandler.session); + serverHandler.awaitSession(5000); + + TextMessage message = new TextMessage("message1"); + serverHandler.session.sendMessage(message); + clientHandler.awaitMessage(message, 5000); + + CloseStatus expected = new CloseStatus(3500, "Oops"); + serverHandler.session.close(expected); + CloseStatus actual = clientHandler.awaitCloseStatus(5000); + if (transport instanceof XhrTransport) { + assertThat(actual, Matchers.anyOf(equalTo(expected), equalTo(new CloseStatus(3000, "Go away!")))); + } + else { + assertEquals(expected, actual); + } + } + + + @Configuration + @EnableWebSocket + static class TestConfig implements WebSocketConfigurer { + + @Autowired + private RequestUpgradeStrategy upgradeStrategy; + + @Override + public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { + HandshakeHandler handshakeHandler = new DefaultHandshakeHandler(this.upgradeStrategy); + registry.addHandler(new EchoHandler(), "/echo").setHandshakeHandler(handshakeHandler).withSockJS(); + registry.addHandler(testServerHandler(), "/test").setHandshakeHandler(handshakeHandler).withSockJS(); + } + + @Bean + public TestServerHandler testServerHandler() { + return new TestServerHandler(); + } + } + + private static interface Condition { + boolean match(); + } + + private static void awaitEvent(Condition condition, long timeToWait, String description) { + long timeToSleep = 200; + for (int i = 0 ; i < Math.floor(timeToWait / timeToSleep); i++) { + if (condition.match()) { + return; + } + try { + Thread.sleep(timeToSleep); + } + catch (InterruptedException e) { + throw new IllegalStateException("Interrupted while waiting for " + description, e); + } + } + throw new IllegalStateException("Timed out waiting for " + description); + } + + private static class TestClientHandler extends TextWebSocketHandler { + + private final BlockingQueue receivedMessages = new LinkedBlockingQueue<>(); + + private volatile WebSocketSession session; + + private volatile CloseStatus closeStatus; + + + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + this.session = session; + } + + @Override + protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { + this.receivedMessages.add(message); + } + + @Override + public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { + this.closeStatus = status; + } + + public void awaitMessageCount(final int count, long timeToWait) throws Exception { + awaitEvent(() -> receivedMessages.size() >= count, timeToWait, + count + " number of messages. Received so far: " + this.receivedMessages); + } + + public void awaitMessage(TextMessage expected, long timeToWait) throws InterruptedException { + TextMessage actual = this.receivedMessages.poll(timeToWait, TimeUnit.MILLISECONDS); + assertNotNull("Timed out waiting for [" + expected + "]", actual); + assertEquals(expected, actual); + } + + public CloseStatus awaitCloseStatus(long timeToWait) throws InterruptedException { + awaitEvent(() -> this.closeStatus != null, timeToWait, " CloseStatus"); + return this.closeStatus; + } + } + + private static class TestServerHandler extends TextWebSocketHandler { + + private WebSocketSession session; + + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + this.session = session; + } + + public WebSocketSession awaitSession(long timeToWait) throws InterruptedException { + awaitEvent(() -> this.session != null, timeToWait, " session"); + return this.session; + } + } + + private static class EchoHandler extends TextWebSocketHandler { + + @Override + protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { + session.sendMessage(message); + } + } + + private static class ErrorFilter implements Filter { + + private final Map responseStatusMap = new HashMap<>(); + + private final Map sleepDelayMap = new HashMap<>(); + + @Override + public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain) throws IOException, ServletException { + for (String suffix : this.sleepDelayMap.keySet()) { + if (((HttpServletRequest) req).getRequestURI().endsWith(suffix)) { + try { + Thread.sleep(this.sleepDelayMap.get(suffix)); + break; + } + catch (InterruptedException e) { + e.printStackTrace(); + } + } + } + for (String suffix : this.responseStatusMap.keySet()) { + if (((HttpServletRequest) req).getRequestURI().endsWith(suffix)) { + ((HttpServletResponse) resp).sendError(this.responseStatusMap.get(suffix)); + return; + } + } + chain.doFilter(req, resp); + } + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + } + + @Override + public void destroy() { + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java new file mode 100644 index 0000000000..9d3f76b460 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java @@ -0,0 +1,280 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketExtension; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.util.List; + +import static org.junit.Assert.assertThat; +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.*; + +/** + * Unit tests for + * {@link org.springframework.web.socket.sockjs.client.AbstractClientSockJsSession}. + * + * @author Rossen Stoyanchev + */ +public class ClientSockJsSessionTests { + + private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec(); + + private TestClientSockJsSession session; + + private WebSocketHandler handler; + + private SettableListenableFuture connectFuture; + + @Rule + public final ExpectedException thrown = ExpectedException.none(); + + + @Before + public void setup() throws Exception { + SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); + Transport transport = mock(Transport.class); + TransportRequest request = new DefaultTransportRequest(urlInfo, null, transport, TransportType.XHR, CODEC); + this.handler = mock(WebSocketHandler.class); + this.connectFuture = new SettableListenableFuture<>(); + this.session = new TestClientSockJsSession(request, this.handler, this.connectFuture); + } + + + @Test + public void handleFrameOpen() throws Exception { + assertThat(this.session.isOpen(), is(false)); + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + assertThat(this.session.isOpen(), is(true)); + assertTrue(this.connectFuture.isDone()); + assertThat(this.connectFuture.get(), sameInstance(this.session)); + verify(this.handler).afterConnectionEstablished(this.session); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameOpenWhenStatusNotNew() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + assertThat(this.session.isOpen(), is(true)); + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(1006, "Server lost session"))); + } + + @Test + public void handleFrameOpenWithWebSocketHandlerException() throws Exception { + doThrow(new IllegalStateException("Fake error")).when(this.handler).afterConnectionEstablished(this.session); + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + assertThat(this.session.isOpen(), is(true)); + } + + @Test + public void handleFrameMessage() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent()); + verify(this.handler).afterConnectionEstablished(this.session); + verify(this.handler).handleMessage(this.session, new TextMessage("foo")); + verify(this.handler).handleMessage(this.session, new TextMessage("bar")); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameMessageWhenNotOpen() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.close(); + reset(this.handler); + this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent()); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameMessageWithBadData() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.handleFrame("a['bad data"); + assertThat(this.session.isOpen(), equalTo(false)); + assertThat(this.session.disconnectStatus, equalTo(CloseStatus.BAD_DATA)); + verify(this.handler).afterConnectionEstablished(this.session); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameMessageWithWebSocketHandlerException() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + doThrow(new IllegalStateException("Fake error")).when(this.handler).handleMessage(this.session, new TextMessage("foo")); + doThrow(new IllegalStateException("Fake error")).when(this.handler).handleMessage(this.session, new TextMessage("bar")); + this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent()); + assertThat(this.session.isOpen(), equalTo(true)); + verify(this.handler).afterConnectionEstablished(this.session); + verify(this.handler).handleMessage(this.session, new TextMessage("foo")); + verify(this.handler).handleMessage(this.session, new TextMessage("bar")); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleFrameClose() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.handleFrame(SockJsFrame.closeFrame(1007, "").getContent()); + assertThat(this.session.isOpen(), equalTo(false)); + assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(1007, ""))); + verify(this.handler).afterConnectionEstablished(this.session); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void handleTransportError() throws Exception { + final IllegalStateException ex = new IllegalStateException("Fake error"); + this.session.handleTransportError(ex); + verify(this.handler).handleTransportError(this.session, ex); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void afterTransportClosed() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.afterTransportClosed(CloseStatus.SERVER_ERROR); + assertThat(this.session.isOpen(), equalTo(false)); + verify(this.handler).afterConnectionEstablished(this.session); + verify(this.handler).afterConnectionClosed(this.session, CloseStatus.SERVER_ERROR); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void close() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.close(); + assertThat(this.session.isOpen(), equalTo(false)); + assertThat(this.session.disconnectStatus, equalTo(CloseStatus.NORMAL)); + verify(this.handler).afterConnectionEstablished(this.session); + verifyNoMoreInteractions(this.handler); + } + + @Test + public void closeWithStatus() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.close(new CloseStatus(3000, "reason")); + assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(3000, "reason"))); + } + + @Test + public void closeWithNullStatus() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("Invalid close status"); + this.session.close(null); + } + + @Test + public void closeWithStatusOutOfRange() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.thrown.expect(IllegalArgumentException.class); + this.thrown.expectMessage("Invalid close status"); + this.session.close(new CloseStatus(2999, "reason")); + } + + @Test + public void timeoutTask() { + this.session.getTimeoutTask().run(); + assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(2007, "Transport timed out"))); + } + + @Test + public void send() throws Exception { + this.session.handleFrame(SockJsFrame.openFrame().getContent()); + this.session.sendMessage(new TextMessage("foo")); + assertThat(this.session.sentMessage, equalTo(new TextMessage("[\"foo\"]"))); + } + + + private static class TestClientSockJsSession extends AbstractClientSockJsSession { + + private TextMessage sentMessage; + + private CloseStatus disconnectStatus; + + + protected TestClientSockJsSession(TransportRequest request, WebSocketHandler handler, + SettableListenableFuture connectFuture) { + super(request, handler, connectFuture); + } + + @Override + protected void sendInternal(TextMessage textMessage) throws IOException { + this.sentMessage = textMessage; + } + + @Override + protected void disconnect(CloseStatus status) throws IOException { + this.disconnectStatus = status; + } + + @Override + public InetSocketAddress getLocalAddress() { + return null; + } + + @Override + public InetSocketAddress getRemoteAddress() { + return null; + } + + @Override + public String getAcceptedProtocol() { + return null; + } + + @Override + public void setTextMessageSizeLimit(int messageSizeLimit) { + + } + + @Override + public int getTextMessageSizeLimit() { + return 0; + } + + @Override + public void setBinaryMessageSizeLimit(int messageSizeLimit) { + + } + + @Override + public int getBinaryMessageSizeLimit() { + return 0; + } + + @Override + public List getExtensions() { + return null; + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java new file mode 100644 index 0000000000..6d1eeda631 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java @@ -0,0 +1,139 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpHeaders; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.io.IOException; +import java.net.URI; +import java.util.Date; +import java.util.concurrent.ExecutionException; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; + +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Unit tests for {@link DefaultTransportRequest}. + * + * @author Rossen Stoyanchev + */ +public class DefaultTransportRequestTests { + + private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec(); + + + private SettableListenableFuture connectFuture; + + private ListenableFutureCallback connectCallback; + + private TestTransport webSocketTransport; + + private TestTransport xhrTransport; + + + @Rule + public final ExpectedException thrown = ExpectedException.none(); + + + @SuppressWarnings("unchecked") + @Before + public void setup() throws Exception { + this.connectCallback = mock(ListenableFutureCallback.class); + this.connectFuture = new SettableListenableFuture<>(); + this.connectFuture.addCallback(this.connectCallback); + this.webSocketTransport = new TestTransport("WebSocketTestTransport"); + this.xhrTransport = new TestTransport("XhrTestTransport"); + } + + + @Test + @SuppressWarnings("unchecked") + public void connect() throws Exception { + DefaultTransportRequest request = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET); + request.connect(null, this.connectFuture); + WebSocketSession session = mock(WebSocketSession.class); + this.webSocketTransport.getConnectCallback().onSuccess(session); + assertSame(session, this.connectFuture.get()); + } + + @Test + public void fallbackAfterTransportError() throws Exception { + DefaultTransportRequest request1 = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET); + DefaultTransportRequest request2 = createTransportRequest(this.xhrTransport, TransportType.XHR_STREAMING); + request1.setFallbackRequest(request2); + request1.connect(null, this.connectFuture); + + // Transport error => fallback + this.webSocketTransport.getConnectCallback().onFailure(new IOException("Fake exception 1")); + assertFalse(this.connectFuture.isDone()); + assertTrue(this.xhrTransport.invoked()); + + // Transport error => no more fallback + this.xhrTransport.getConnectCallback().onFailure(new IOException("Fake exception 2")); + assertTrue(this.connectFuture.isDone()); + this.thrown.expect(ExecutionException.class); + this.thrown.expectMessage("Fake exception 2"); + this.connectFuture.get(); + } + + @Test + public void fallbackAfterTimeout() throws Exception { + TaskScheduler scheduler = mock(TaskScheduler.class); + Runnable sessionCleanupTask = mock(Runnable.class); + DefaultTransportRequest request1 = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET); + DefaultTransportRequest request2 = createTransportRequest(this.xhrTransport, TransportType.XHR_STREAMING); + request1.setFallbackRequest(request2); + request1.setTimeoutScheduler(scheduler); + request1.addTimeoutTask(sessionCleanupTask); + request1.connect(null, this.connectFuture); + + assertTrue(this.webSocketTransport.invoked()); + assertFalse(this.xhrTransport.invoked()); + + // Get and invoke the scheduled timeout task + ArgumentCaptor taskCaptor = ArgumentCaptor.forClass(Runnable.class); + verify(scheduler).schedule(taskCaptor.capture(), any(Date.class)); + verifyNoMoreInteractions(scheduler); + taskCaptor.getValue().run(); + + assertTrue(this.xhrTransport.invoked()); + verify(sessionCleanupTask).run(); + } + + protected DefaultTransportRequest createTransportRequest(Transport transport, TransportType type) throws Exception { + SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); + return new DefaultTransportRequest(urlInfo, new HttpHeaders(), transport, type, CODEC); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/JettySockJsIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/JettySockJsIntegrationTests.java new file mode 100644 index 0000000000..31e1868b30 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/JettySockJsIntegrationTests.java @@ -0,0 +1,98 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.junit.After; +import org.junit.Before; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.socket.JettyWebSocketTestServer; +import org.springframework.web.socket.client.jetty.JettyWebSocketClient; +import org.springframework.web.socket.server.RequestUpgradeStrategy; +import org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy; + +/** + * SockJS integration tests using Jetty for client and server. + * + * @author Rossen Stoyanchev + */ +public class JettySockJsIntegrationTests extends AbstractSockJsIntegrationTests { + + private WebSocketClient webSocketClient; + + private HttpClient httpClient; + + + @Before + public void setup() throws Exception { + super.setup(); + this.webSocketClient = new WebSocketClient(); + this.webSocketClient.start(); + this.httpClient = new HttpClient(); + this.httpClient.start(); + } + + @After + public void teardown() throws Exception { + super.teardown(); + try { + this.webSocketClient.stop(); + } + catch (Throwable ex) { + logger.error("Failed to stop Jetty WebSocketClient", ex); + } + try { + this.httpClient.stop(); + } + catch (Throwable ex) { + logger.error("Failed to stop Jetty HttpClient", ex); + } + } + + @Override + protected JettyWebSocketTestServer createWebSocketTestServer() { + return new JettyWebSocketTestServer(); + } + + @Override + protected Class upgradeStrategyConfigClass() { + return JettyTestConfig.class; + } + + @Override + protected Transport getWebSocketTransport() { + return new WebSocketTransport(new JettyWebSocketClient(this.webSocketClient)); + } + + @Override + protected AbstractXhrTransport getXhrTransport() { + return new JettyXhrTransport(this.httpClient); + } + + + @Configuration + static class JettyTestConfig { + + @Bean + public RequestUpgradeStrategy upgradeStrategy() { + return new JettyRequestUpgradeStrategy(); + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java new file mode 100644 index 0000000000..f5b9318d4a --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java @@ -0,0 +1,228 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompEncoder; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.client.RequestCallback; +import org.springframework.web.client.ResponseExtractor; +import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.frame.SockJsFrame; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingDeque; + +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Unit tests for {@link RestTemplateXhrTransport}. + * + * @author Rossen Stoyanchev + */ +public class RestTemplateXhrTransportTests { + + private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec(); + + private WebSocketHandler webSocketHandler; + + + @Before + public void setup() throws Exception { + this.webSocketHandler = mock(WebSocketHandler.class); + } + + + @Test + public void connectReceiveAndClose() throws Exception { + String body = "o\n" + "a[\"foo\"]\n" + "c[3000,\"Go away!\"]"; + ClientHttpResponse response = response(HttpStatus.OK, body); + connect(response); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).handleMessage(any(), eq(new TextMessage("foo"))); + verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!"))); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void connectReceiveAndCloseWithPrelude() throws Exception { + StringBuilder sb = new StringBuilder(2048); + for (int i=0; i < 2048; i++) { + sb.append('h'); + } + String body = sb.toString() + "\n" + "o\n" + "a[\"foo\"]\n" + "c[3000,\"Go away!\"]"; + ClientHttpResponse response = response(HttpStatus.OK, body); + connect(response); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).handleMessage(any(), eq(new TextMessage("foo"))); + verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!"))); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void connectReceiveAndCloseWithStompFrame() throws Exception { + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND); + accessor.setDestination("/destination"); + MessageHeaders headers = accessor.getMessageHeaders(); + Message message = MessageBuilder.createMessage("body".getBytes(Charset.forName("UTF-8")), headers); + byte[] bytes = new StompEncoder().encode(message); + TextMessage textMessage = new TextMessage(bytes); + SockJsFrame frame = SockJsFrame.messageFrame(new Jackson2SockJsMessageCodec(), textMessage.getPayload()); + + String body = "o\n" + frame.getContent() + "\n" + "c[3000,\"Go away!\"]"; + ClientHttpResponse response = response(HttpStatus.OK, body); + connect(response); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).handleMessage(any(), eq(textMessage)); + verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!"))); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void connectFailure() throws Exception { + final HttpServerErrorException expected = new HttpServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR); + RestOperations restTemplate = mock(RestOperations.class); + when(restTemplate.execute(any(), eq(HttpMethod.POST), any(), any())).thenThrow(expected); + + final CountDownLatch latch = new CountDownLatch(1); + connect(restTemplate).addCallback( + new ListenableFutureCallback() { + @Override + public void onSuccess(WebSocketSession result) { + } + @Override + public void onFailure(Throwable actual) { + if (actual == expected) { + latch.countDown(); + } + } + } + ); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void errorResponseStatus() throws Exception { + connect(response(HttpStatus.OK, "o\n"), response(HttpStatus.INTERNAL_SERVER_ERROR, "Oops")); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).handleTransportError(any(), any()); + verify(this.webSocketHandler).afterConnectionClosed(any(), any()); + verifyNoMoreInteractions(this.webSocketHandler); + } + + @Test + public void responseClosedAfterDisconnected() throws Exception { + String body = "o\n" + "c[3000,\"Go away!\"]\n" + "a[\"foo\"]\n"; + ClientHttpResponse response = response(HttpStatus.OK, body); + connect(response); + + verify(this.webSocketHandler).afterConnectionEstablished(any()); + verify(this.webSocketHandler).afterConnectionClosed(any(), any()); + verifyNoMoreInteractions(this.webSocketHandler); + verify(response).close(); + } + + private ListenableFuture connect(ClientHttpResponse... responses) throws Exception { + return connect(new TestRestTemplate(responses)); + } + + private ListenableFuture connect(RestOperations restTemplate, ClientHttpResponse... responses) + throws Exception { + + RestTemplateXhrTransport transport = new RestTemplateXhrTransport(restTemplate); + transport.setTaskExecutor(new SyncTaskExecutor()); + + SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); + HttpHeaders headers = new HttpHeaders(); + headers.add("h-foo", "h-bar"); + TransportRequest request = new DefaultTransportRequest(urlInfo, headers, transport, TransportType.XHR, CODEC); + + return transport.connect(request, this.webSocketHandler); + } + + private ClientHttpResponse response(HttpStatus status, String body) throws IOException { + ClientHttpResponse response = mock(ClientHttpResponse.class); + InputStream inputStream = getInputStream(body); + when(response.getStatusCode()).thenReturn(status); + when(response.getBody()).thenReturn(inputStream); + return response; + } + + private InputStream getInputStream(String content) { + byte[] bytes = content.getBytes(Charset.forName("UTF-8")); + return new ByteArrayInputStream(bytes); + } + + + + private static class TestRestTemplate extends RestTemplate { + + private Queue responses = new LinkedBlockingDeque<>(); + + + private TestRestTemplate(ClientHttpResponse... responses) { + this.responses.addAll(Arrays.asList(responses)); + } + + @Override + public T execute(URI url, HttpMethod method, RequestCallback callback, ResponseExtractor extractor) throws RestClientException { + try { + extractor.extractData(this.responses.remove()); + } + catch (Throwable t) { + throw new RestClientException("Failed to invoke extractor", t); + } + return null; + } + } + + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java new file mode 100644 index 0000000000..b304257803 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpStatus; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.sockjs.client.TestTransport.XhrTestTransport; + +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.*; + +/** + * Unit tests for {@link org.springframework.web.socket.sockjs.client.SockJsClient}. + * + * @author Rossen Stoyanchev + */ +public class SockJsClientTests { + + private static final String URL = "http://example.com"; + + private static final WebSocketHandler handler = mock(WebSocketHandler.class); + + + private SockJsClient sockJsClient; + + private InfoReceiver infoReceiver; + + private TestTransport webSocketTransport; + + private XhrTestTransport xhrTransport; + + private ListenableFutureCallback connectCallback; + + + @Before + @SuppressWarnings("unchecked") + public void setup() { + this.infoReceiver = mock(InfoReceiver.class); + this.webSocketTransport = new TestTransport("WebSocketTestTransport"); + this.xhrTransport = new XhrTestTransport("XhrTestTransport"); + + List transports = new ArrayList<>(); + transports.add(this.webSocketTransport); + transports.add(this.xhrTransport); + this.sockJsClient = new SockJsClient(transports); + this.sockJsClient.setInfoReceiver(this.infoReceiver); + + this.connectCallback = mock(ListenableFutureCallback.class); + } + + @Test + public void connectWebSocket() throws Exception { + setupInfoRequest(true); + this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback); + assertTrue(this.webSocketTransport.invoked()); + WebSocketSession session = mock(WebSocketSession.class); + this.webSocketTransport.getConnectCallback().onSuccess(session); + verify(this.connectCallback).onSuccess(session); + verifyNoMoreInteractions(this.connectCallback); + } + + @Test + public void connectWebSocketDisabled() throws URISyntaxException { + setupInfoRequest(false); + this.sockJsClient.doHandshake(handler, URL); + assertFalse(this.webSocketTransport.invoked()); + assertTrue(this.xhrTransport.invoked()); + assertTrue(this.xhrTransport.getRequest().getTransportUrl().toString().endsWith("xhr_streaming")); + } + + @Test + public void connectXhrStreamingDisabled() throws Exception { + setupInfoRequest(false); + this.xhrTransport.setStreamingDisabled(true); + this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback); + assertFalse(this.webSocketTransport.invoked()); + assertTrue(this.xhrTransport.invoked()); + assertTrue(this.xhrTransport.getRequest().getTransportUrl().toString().endsWith("xhr")); + } + + @Test + public void connectSockJsInfo() throws Exception { + setupInfoRequest(true); + this.sockJsClient.doHandshake(handler, URL); + verify(this.infoReceiver, times(1)).executeInfoRequest(any()); + } + + @Test + public void connectSockJsInfoCached() throws Exception { + setupInfoRequest(true); + this.sockJsClient.doHandshake(handler, URL); + this.sockJsClient.doHandshake(handler, URL); + this.sockJsClient.doHandshake(handler, URL); + verify(this.infoReceiver, times(1)).executeInfoRequest(any()); + } + + @Test + @SuppressWarnings("unchecked") + public void connectInfoRequestFailure() throws URISyntaxException { + HttpServerErrorException exception = new HttpServerErrorException(HttpStatus.SERVICE_UNAVAILABLE); + when(this.infoReceiver.executeInfoRequest(any())).thenThrow(exception); + this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback); + verify(this.connectCallback).onFailure(exception); + assertFalse(this.webSocketTransport.invoked()); + assertFalse(this.xhrTransport.invoked()); + } + + private void setupInfoRequest(boolean webSocketEnabled) { + when(this.infoReceiver.executeInfoRequest(any())).thenReturn("{\"entropy\":123," + + "\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":" + webSocketEnabled + "}"); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfoTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfoTests.java new file mode 100644 index 0000000000..462e27b783 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsUrlInfoTests.java @@ -0,0 +1,90 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Assert; +import org.junit.Test; +import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec; +import org.springframework.web.socket.sockjs.transport.TransportType; + +import java.net.URI; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Unit tests for {@code SockJsUrlInfo}. + * @author Rossen Stoyanchev + */ +public class SockJsUrlInfoTests { + + + @Test + public void serverId() throws Exception { + SockJsUrlInfo info = new SockJsUrlInfo(new URI("http://example.com")); + int serverId = Integer.valueOf(info.getServerId()); + assertTrue("Invalid serverId: " + serverId, serverId > 0 && serverId < 1000); + } + + @Test + public void sessionId() throws Exception { + SockJsUrlInfo info = new SockJsUrlInfo(new URI("http://example.com")); + assertEquals("Invalid sessionId: " + info.getSessionId(), 32, info.getSessionId().length()); + } + + @Test + public void infoUrl() throws Exception { + testInfoUrl("http", "http"); + testInfoUrl("http", "http"); + testInfoUrl("https", "https"); + testInfoUrl("https", "https"); + testInfoUrl("ws", "http"); + testInfoUrl("ws", "http"); + testInfoUrl("wss", "https"); + testInfoUrl("wss", "https"); + } + + private void testInfoUrl(String scheme, String expectedScheme) throws Exception { + SockJsUrlInfo info = new SockJsUrlInfo(new URI(scheme + "://example.com")); + Assert.assertThat(info.getInfoUrl(), is(equalTo(new URI(expectedScheme + "://example.com/info")))); + } + + @Test + public void transportUrl() throws Exception { + testTransportUrl("http", "http", TransportType.XHR_STREAMING); + testTransportUrl("http", "ws", TransportType.WEBSOCKET); + testTransportUrl("https", "https", TransportType.XHR_STREAMING); + testTransportUrl("https", "wss", TransportType.WEBSOCKET); + testTransportUrl("ws", "http", TransportType.XHR_STREAMING); + testTransportUrl("ws", "ws", TransportType.WEBSOCKET); + testTransportUrl("wss", "https", TransportType.XHR_STREAMING); + testTransportUrl("wss", "wss", TransportType.WEBSOCKET); + } + + private void testTransportUrl(String scheme, String expectedScheme, TransportType transportType) throws Exception { + SockJsUrlInfo info = new SockJsUrlInfo(new URI(scheme + "://example.com")); + String serverId = info.getServerId(); + String sessionId = info.getSessionId(); + String transport = transportType.toString().toLowerCase(); + URI expected = new URI(expectedScheme + "://example.com/" + serverId + "/" + sessionId + "/" + transport); + assertThat(info.getTransportUrl(transportType), equalTo(expected)); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java new file mode 100644 index 0000000000..f54083b2ce --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.mockito.ArgumentCaptor; +import org.springframework.util.concurrent.ListenableFuture; +import org.springframework.util.concurrent.ListenableFutureCallback; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; + +import java.net.URI; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Test SockJS Transport. + * + * @author Rossen Stoyanchev + */ +class TestTransport implements Transport { + + private final String name; + + private TransportRequest request; + + private ListenableFuture future; + + + public TestTransport(String name) { + this.name = name; + } + + public TransportRequest getRequest() { + return this.request; + } + + public boolean invoked() { + return this.future != null; + } + + @SuppressWarnings("unchecked") + public ListenableFutureCallback getConnectCallback() { + ArgumentCaptor captor = ArgumentCaptor.forClass(ListenableFutureCallback.class); + verify(this.future).addCallback(captor.capture()); + return captor.getValue(); + } + + @SuppressWarnings("unchecked") + @Override + public ListenableFuture connect(TransportRequest request, WebSocketHandler handler) { + this.request = request; + this.future = mock(ListenableFuture.class); + return this.future; + } + + @Override + public String toString() { + return "TestTransport[" + name + "]"; + } + + + static class XhrTestTransport extends TestTransport implements XhrTransport { + + private boolean streamingDisabled; + + + XhrTestTransport(String name) { + super(name); + } + + public void setStreamingDisabled(boolean streamingDisabled) { + this.streamingDisabled = streamingDisabled; + } + + @Override + public boolean isXhrStreamingDisabled() { + return this.streamingDisabled; + } + + @Override + public void executeSendRequest(URI transportUrl, TextMessage message) { + } + + @Override + public String executeInfoRequest(URI infoUrl) { + return null; + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java new file mode 100644 index 0000000000..d13bc207a1 --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java @@ -0,0 +1,155 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.client; + +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.concurrent.SettableListenableFuture; +import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketSession; + +import java.net.URI; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.notNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +/** + * Unit tests for + * {@link org.springframework.web.socket.sockjs.client.AbstractXhrTransport}. + * + * @author Rossen Stoyanchev + */ +public class XhrTransportTests { + + @Test + public void infoResponse() throws Exception { + TestXhrTransport transport = new TestXhrTransport(); + transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK); + assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"))); + } + + @Test(expected = HttpServerErrorException.class) + public void infoResponseError() throws Exception { + TestXhrTransport transport = new TestXhrTransport(); + transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.BAD_REQUEST); + assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"))); + } + + @Test + public void sendMessage() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("foo", "bar"); + TestXhrTransport transport = new TestXhrTransport(); + transport.setRequestHeaders(requestHeaders); + transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.NO_CONTENT); + URI url = new URI("http://example.com"); + transport.executeSendRequest(url, new TextMessage("payload")); + assertEquals(2, transport.actualSendRequestHeaders.size()); + assertEquals("bar", transport.actualSendRequestHeaders.getFirst("foo")); + assertEquals(MediaType.APPLICATION_JSON, transport.actualSendRequestHeaders.getContentType()); + } + + @Test(expected = HttpServerErrorException.class) + public void sendMessageError() throws Exception { + TestXhrTransport transport = new TestXhrTransport(); + transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.BAD_REQUEST); + URI url = new URI("http://example.com"); + transport.executeSendRequest(url, new TextMessage("payload")); + } + + @Test + public void connect() throws Exception { + HttpHeaders handshakeHeaders = new HttpHeaders(); + handshakeHeaders.setOrigin("foo"); + + TransportRequest request = mock(TransportRequest.class); + when(request.getSockJsUrlInfo()).thenReturn(new SockJsUrlInfo(new URI("http://example.com"))); + when(request.getHandshakeHeaders()).thenReturn(handshakeHeaders); + + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("foo", "bar"); + + TestXhrTransport transport = new TestXhrTransport(); + transport.setRequestHeaders(requestHeaders); + + WebSocketHandler handler = mock(WebSocketHandler.class); + transport.connect(request, handler); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); + verify(request).getSockJsUrlInfo(); + verify(request).addTimeoutTask(captor.capture()); + verify(request).getTransportUrl(); + verify(request).getHandshakeHeaders(); + verifyNoMoreInteractions(request); + + assertEquals(2, transport.actualHandshakeHeaders.size()); + assertEquals("foo", transport.actualHandshakeHeaders.getOrigin()); + assertEquals("bar", transport.actualHandshakeHeaders.getFirst("foo")); + + assertFalse(transport.actualSession.isDisconnected()); + captor.getValue().run(); + assertTrue(transport.actualSession.isDisconnected()); + } + + + private static class TestXhrTransport extends AbstractXhrTransport { + + private ResponseEntity infoResponseToReturn; + + private ResponseEntity sendMessageResponseToReturn; + + private HttpHeaders actualSendRequestHeaders; + + private HttpHeaders actualHandshakeHeaders; + + private XhrClientSockJsSession actualSession; + + + @Override + protected ResponseEntity executeInfoRequestInternal(URI infoUrl) { + return this.infoResponseToReturn; + } + + @Override + protected ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) { + this.actualSendRequestHeaders = headers; + return this.sendMessageResponseToReturn; + } + + @Override + protected void connectInternal(TransportRequest request, WebSocketHandler handler, URI receiveUrl, + HttpHeaders handshakeHeaders, XhrClientSockJsSession session, + SettableListenableFuture connectFuture) { + + this.actualHandshakeHeaders = handshakeHeaders; + this.actualSession = session; + } + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/frame/SockJsFrameTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/frame/SockJsFrameTests.java new file mode 100644 index 0000000000..d32b49d51d --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/frame/SockJsFrameTests.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.sockjs.frame; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** + * Unit tests for {@link org.springframework.web.socket.sockjs.frame.SockJsFrame}. + * + * @author Rossen Stoyanchev + * @since 4.1 + */ +public class SockJsFrameTests { + + + @Test + public void openFrame() { + SockJsFrame frame = SockJsFrame.openFrame(); + + assertEquals("o", frame.getContent()); + assertEquals(SockJsFrameType.OPEN, frame.getType()); + assertNull(frame.getFrameData()); + } + + @Test + public void heartbeatFrame() { + SockJsFrame frame = SockJsFrame.heartbeatFrame(); + + assertEquals("h", frame.getContent()); + assertEquals(SockJsFrameType.HEARTBEAT, frame.getType()); + assertNull(frame.getFrameData()); + } + + @Test + public void messageArrayFrame() { + SockJsFrame frame = SockJsFrame.messageFrame(new Jackson2SockJsMessageCodec(), "m1", "m2"); + + assertEquals("a[\"m1\",\"m2\"]", frame.getContent()); + assertEquals(SockJsFrameType.MESSAGE, frame.getType()); + assertEquals("[\"m1\",\"m2\"]", frame.getFrameData()); + } + + @Test + public void messageArrayFrameEmpty() { + SockJsFrame frame = new SockJsFrame("a"); + + assertEquals("a[]", frame.getContent()); + assertEquals(SockJsFrameType.MESSAGE, frame.getType()); + assertEquals("[]", frame.getFrameData()); + + frame = new SockJsFrame("a[]"); + + assertEquals("a[]", frame.getContent()); + assertEquals(SockJsFrameType.MESSAGE, frame.getType()); + assertEquals("[]", frame.getFrameData()); + } + + @Test + public void closeFrame() { + SockJsFrame frame = SockJsFrame.closeFrame(3000, "Go Away!"); + + assertEquals("c[3000,\"Go Away!\"]", frame.getContent()); + assertEquals(SockJsFrameType.CLOSE, frame.getType()); + assertEquals("[3000,\"Go Away!\"]", frame.getFrameData()); + } + + @Test + public void closeFrameEmpty() { + SockJsFrame frame = new SockJsFrame("c"); + + assertEquals("c[]", frame.getContent()); + assertEquals(SockJsFrameType.CLOSE, frame.getType()); + assertEquals("[]", frame.getFrameData()); + + frame = new SockJsFrame("c[]"); + + assertEquals("c[]", frame.getContent()); + assertEquals(SockJsFrameType.CLOSE, frame.getType()); + assertEquals("[]", frame.getFrameData()); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java index 3797f9024b..dcab42610e 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/HttpSendingTransportHandlerTests.java @@ -168,24 +168,24 @@ public class HttpSendingTransportHandlerTests extends AbstractHttpRequestTests SockJsFrame frame = SockJsFrame.openFrame(); SockJsFrameFormat format = new XhrPollingTransportHandler().getFrameFormat(this.request); - SockJsFrame formatted = format.format(frame); - assertEquals(frame.getContent() + "\n", formatted.getContent()); + String formatted = format.format(frame); + assertEquals(frame.getContent() + "\n", formatted); format = new XhrStreamingTransportHandler().getFrameFormat(this.request); formatted = format.format(frame); - assertEquals(frame.getContent() + "\n", formatted.getContent()); + assertEquals(frame.getContent() + "\n", formatted); format = new HtmlFileTransportHandler().getFrameFormat(this.request); formatted = format.format(frame); - assertEquals("\r\n", formatted.getContent()); + assertEquals("\r\n", formatted); format = new EventSourceTransportHandler().getFrameFormat(this.request); formatted = format.format(frame); - assertEquals("data: " + frame.getContent() + "\r\n\r\n", formatted.getContent()); + assertEquals("data: " + frame.getContent() + "\r\n\r\n", formatted); format = new JsonpPollingTransportHandler().getFrameFormat(this.request); formatted = format.format(frame); - assertEquals("callback(\"" + frame.getContent() + "\");\r\n", formatted.getContent()); + assertEquals("callback(\"" + frame.getContent() + "\");\r\n", formatted); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/HttpSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/HttpSockJsSessionTests.java index d0c96bdf4c..b2ec62daea 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/HttpSockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/session/HttpSockJsSessionTests.java @@ -94,10 +94,8 @@ public class HttpSockJsSessionTests extends AbstractSockJsSessionTests expected = Arrays.asList(new TextMessage("o"), new TextMessage("a[\"go go\"]")); + assertEquals(expected, this.webSocketSession.getSentMessages()); + } + @Test public void handleMessageEmptyPayload() throws Exception { this.session.handleMessage(new TextMessage(""), this.webSocketSession); diff --git a/spring-websocket/src/test/resources/log4j.properties b/spring-websocket/src/test/resources/log4j.properties index 8db186fb4e..0b8d9ec5f6 100644 --- a/spring-websocket/src/test/resources/log4j.properties +++ b/spring-websocket/src/test/resources/log4j.properties @@ -1,9 +1,9 @@ log4j.appender.console=org.apache.log4j.ConsoleAppender log4j.appender.console.layout=org.apache.log4j.PatternLayout -log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} [%c] - %m%n +log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} [%c][%t] - %m%n log4j.rootCategory=WARN, console log4j.logger.org.springframework.web=DEBUG -log4j.logger.org.springframework.web.socket=DEBUG +log4j.logger.org.springframework.web.socket=TRACE log4j.logger.org.springframework.messaging=DEBUG