diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java index 2c6bce9cd6..2dbd815706 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/client/standard/StandardWebSocketClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,10 +21,12 @@ import java.net.InetSocketAddress; import java.net.URI; import java.net.UnknownHostException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.concurrent.Callable; + import javax.websocket.ClientEndpointConfig; import javax.websocket.ClientEndpointConfig.Configurator; import javax.websocket.ContainerProvider; @@ -49,9 +51,9 @@ import org.springframework.web.socket.adapter.standard.StandardWebSocketSession; import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter; import org.springframework.web.socket.client.AbstractWebSocketClient; + /** - * Initiates WebSocket requests to a WebSocket server programmatically - * through the standard Java WebSocket API. + * A WebSocketClient based on standard Java WebSocket API. * * @author Rossen Stoyanchev * @since 4.0 @@ -60,6 +62,8 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { private final WebSocketContainer webSocketContainer; + private final Map userProperties = new HashMap(); + private AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(); @@ -84,6 +88,25 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { } + /** + * The standard Java WebSocket API allows passing "user properties" to the + * server via {@link ClientEndpointConfig#getUserProperties() userProperties}. + * Use this property to configure one or more properties to be passed on + * every handshake. + */ + public void setUserProperties(Map userProperties) { + if (userProperties != null) { + this.userProperties.putAll(userProperties); + } + } + + /** + * The configured user properties, or {@code null}. + */ + public Map getUserProperties() { + return this.userProperties; + } + /** * Set an {@link AsyncListenableTaskExecutor} to use when opening connections. * If this property is set to {@code null}, calls to any of the @@ -114,16 +137,19 @@ public class StandardWebSocketClient extends AbstractWebSocketClient { final StandardWebSocketSession session = new StandardWebSocketSession(headers, attributes, localAddress, remoteAddress); - final ClientEndpointConfig.Builder configBuilder = ClientEndpointConfig.Builder.create(); - configBuilder.configurator(new StandardWebSocketClientConfigurator(headers)); - configBuilder.preferredSubprotocols(protocols); - configBuilder.extensions(adaptExtensions(extensions)); + final ClientEndpointConfig endpointConfig = ClientEndpointConfig.Builder.create() + .configurator(new StandardWebSocketClientConfigurator(headers)) + .preferredSubprotocols(protocols) + .extensions(adaptExtensions(extensions)).build(); + + endpointConfig.getUserProperties().putAll(getUserProperties()); + final Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session); Callable connectTask = new Callable() { @Override public WebSocketSession call() throws Exception { - webSocketContainer.connectToServer(endpoint, configBuilder.build(), uri); + webSocketContainer.connectToServer(endpoint, endpointConfig, uri); return session; } }; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/client/standard/StandardWebSocketClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/client/standard/StandardWebSocketClientTests.java index db5fb74ea1..4394f2f056 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/client/standard/StandardWebSocketClientTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/client/standard/StandardWebSocketClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,12 +16,15 @@ package org.springframework.web.socket.client.standard; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + import java.net.URI; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; + import javax.websocket.ClientEndpointConfig; import javax.websocket.Endpoint; import javax.websocket.WebSocketContainer; @@ -36,9 +39,6 @@ import org.springframework.web.socket.WebSocketHttpHeaders; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.AbstractWebSocketHandler; -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; - /** * Test fixture for {@link StandardWebSocketClient}. * @@ -66,8 +66,8 @@ public class StandardWebSocketClientTests { @Test - public void localAddress() throws Exception { - URI uri = new URI("ws://example.com/abc"); + public void testGetLocalAddress() throws Exception { + URI uri = new URI("ws://localhost/abc"); WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get(); assertNotNull(session.getLocalAddress()); @@ -75,8 +75,8 @@ public class StandardWebSocketClientTests { } @Test - public void localAddressWss() throws Exception { - URI uri = new URI("wss://example.com/abc"); + public void testGetLocalAddressWss() throws Exception { + URI uri = new URI("wss://localhost/abc"); WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get(); assertNotNull(session.getLocalAddress()); @@ -84,61 +84,88 @@ public class StandardWebSocketClientTests { } @Test(expected=IllegalArgumentException.class) - public void localAddressNoScheme() throws Exception { - URI uri = new URI("example.com/abc"); + public void testGetLocalAddressNoScheme() throws Exception { + URI uri = new URI("localhost/abc"); this.wsClient.doHandshake(this.wsHandler, this.headers, uri); } @Test - public void remoteAddress() throws Exception { - URI uri = new URI("wss://example.com/abc"); + public void testGetRemoteAddress() throws Exception { + URI uri = new URI("wss://localhost/abc"); WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get(); assertNotNull(session.getRemoteAddress()); - assertEquals("example.com", session.getRemoteAddress().getHostName()); + assertEquals("localhost", session.getRemoteAddress().getHostName()); assertEquals(443, session.getLocalAddress().getPort()); } @Test - public void headersWebSocketSession() throws Exception { + public void handshakeHeaders() throws Exception { - URI uri = new URI("ws://example.com/abc"); - List protocols = Arrays.asList("abc"); + URI uri = new URI("ws://localhost/abc"); + List protocols = Collections.singletonList("abc"); this.headers.setSecWebSocketProtocol(protocols); this.headers.add("foo", "bar"); WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get(); - assertEquals(Collections.singletonMap("foo", Arrays.asList("bar")), session.getHandshakeHeaders()); + assertEquals(1, session.getHandshakeHeaders().size()); + assertEquals("bar", session.getHandshakeHeaders().getFirst("foo")); } @Test - public void headersClientEndpointConfigurator() throws Exception { + public void clientEndpointConfig() throws Exception { - URI uri = new URI("ws://example.com/abc"); - List protocols = Arrays.asList("abc"); + URI uri = new URI("ws://localhost/abc"); + List protocols = Collections.singletonList("abc"); this.headers.setSecWebSocketProtocol(protocols); - this.headers.add("foo", "bar"); this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get(); - ArgumentCaptor arg1 = ArgumentCaptor.forClass(Endpoint.class); - ArgumentCaptor arg2 = ArgumentCaptor.forClass(ClientEndpointConfig.class); - ArgumentCaptor arg3 = ArgumentCaptor.forClass(URI.class); - verify(this.wsContainer).connectToServer(arg1.capture(), arg2.capture(), arg3.capture()); + ArgumentCaptor captor = ArgumentCaptor.forClass(ClientEndpointConfig.class); + verify(this.wsContainer).connectToServer(any(Endpoint.class), captor.capture(), any(URI.class)); + ClientEndpointConfig endpointConfig = captor.getValue(); - ClientEndpointConfig endpointConfig = arg2.getValue(); assertEquals(protocols, endpointConfig.getPreferredSubprotocols()); + } + + @Test + public void clientEndpointConfigWithUserProperties() throws Exception { + + Map userProperties = Collections.singletonMap("foo", "bar"); + + URI uri = new URI("ws://localhost/abc"); + this.wsClient.setUserProperties(userProperties); + this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ClientEndpointConfig.class); + verify(this.wsContainer).connectToServer(any(Endpoint.class), captor.capture(), any(URI.class)); + ClientEndpointConfig endpointConfig = captor.getValue(); + + assertEquals(userProperties, endpointConfig.getUserProperties()); + } + + @Test + public void standardWebSocketClientConfiguratorInsertsHandshakeHeaders() throws Exception { + + URI uri = new URI("ws://localhost/abc"); + this.headers.add("foo", "bar"); + + this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ClientEndpointConfig.class); + verify(this.wsContainer).connectToServer(any(Endpoint.class), captor.capture(), any(URI.class)); + ClientEndpointConfig endpointConfig = captor.getValue(); - Map> map = new HashMap<>(); - endpointConfig.getConfigurator().beforeRequest(map); - assertEquals(Collections.singletonMap("foo", Arrays.asList("bar")), map); + Map> headers = new HashMap<>(); + endpointConfig.getConfigurator().beforeRequest(headers); + assertEquals(1, headers.size()); } @Test public void taskExecutor() throws Exception { - URI uri = new URI("ws://example.com/abc"); + URI uri = new URI("ws://localhost/abc"); this.wsClient.setTaskExecutor(new SimpleAsyncTaskExecutor()); WebSocketSession session = this.wsClient.doHandshake(this.wsHandler, this.headers, uri).get();