@ -33,6 +33,7 @@ import javax.websocket.WebSocketContainer;
@@ -33,6 +33,7 @@ import javax.websocket.WebSocketContainer;
import org.apache.commons.logging.Log ;
import org.apache.commons.logging.LogFactory ;
import org.springframework.http.HttpHeaders ;
import org.springframework.util.Assert ;
import org.springframework.web.socket.WebSocketHandler ;
import org.springframework.web.socket.WebSocketSession ;
import org.springframework.web.socket.adapter.StandardEndpointAdapter ;
@ -53,49 +54,57 @@ public class StandardWebSocketClient implements WebSocketClient {
@@ -53,49 +54,57 @@ public class StandardWebSocketClient implements WebSocketClient {
private static final Log logger = LogFactory . getLog ( StandardWebSocketClient . class ) ;
private WebSocketContainer webSocketContainer ;
private final WebSocketContainer webSocketContainer ;
public WebSocketContainer getWebSocketContainer ( ) {
if ( this . webSocketContainer = = null ) {
this . webSocketContainer = ContainerProvider . getWebSocketContainer ( ) ;
}
return this . webSocketContainer ;
public StandardWebSocketClient ( ) {
this . webSocketContainer = ContainerProvider . getWebSocketContainer ( ) ;
}
public void setWebSocketContainer ( WebSocketContainer container ) {
this . webSocketContainer = container ;
public StandardWebSocketClient ( WebSocketContainer webSocketContainer ) {
Assert . notNull ( webSocketContainer , "webSocketContainer is required" ) ;
this . webSocketContainer = webSocketContainer ;
}
@Override
public WebSocketSession doHandshake ( WebSocketHandler webSocketHandler , String uriTemplate , Object . . . uriVariables )
throws WebSocketConnectFailureException {
Assert . notNull ( uriTemplate , "uriTemplate is required" ) ;
UriComponents uriComponents = UriComponentsBuilder . fromUriString ( uriTemplate ) . buildAndExpand ( uriVariables ) . encode ( ) ;
return doHandshake ( webSocketHandler , null , uriComponents ) ;
return doHandshake ( webSocketHandler , null , uriComponents . toUri ( ) ) ;
}
@Override
public WebSocketSession doHandshake ( WebSocketHandler webSocketHandler , HttpHeaders httpHeaders , URI uri )
throws WebSocketConnectFailureException {
Assert . notNull ( webSocketHandler , "webSocketHandler is required" ) ;
Assert . notNull ( uri , "uri is required" ) ;
httpHeaders = ( httpHeaders ! = null ) ? httpHeaders : new HttpHeaders ( ) ;
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "Connecting to " + uri ) ;
}
StandardWebSocketSessionAdapter session = new StandardWebSocketSessionAdapter ( ) ;
session . setUri ( uri ) ;
session . setRemoteHostName ( uri . getHost ( ) ) ;
Endpoint endpoint = new StandardEndpointAdapter ( webSocketHandler , session ) ;
ClientEndpointConfig . Builder configBuidler = ClientEndpointConfig . Builder . create ( ) ;
if ( httpHeaders ! = null ) {
List < String > protocols = httpHeaders . getSecWebSocketProtocol ( ) ;
if ( ! protocols . isEmpty ( ) ) {
configBuidler . preferredSubprotocols ( protocols ) ;
}
configBuidler . configurator ( new StandardWebSocketClientConfigurator ( httpHeaders ) ) ;
configBuidler . configurator ( new StandardWebSocketClientConfigurator ( httpHeaders ) ) ;
List < String > protocols = httpHeaders . getSecWebSocketProtocol ( ) ;
if ( ! protocols . isEmpty ( ) ) {
configBuidler . preferredSubprotocols ( protocols ) ;
}
try {
// TODO: do not block
Endpoint endpoint = new StandardEndpointAdapter ( webSocketHandler , session ) ;
this . webSocketContainer . connectToServer ( endpoint , configBuidler . build ( ) , uri ) ;
return session ;
}
catch ( Exception e ) {
@ -128,14 +137,14 @@ public class StandardWebSocketClient implements WebSocketClient {
@@ -128,14 +137,14 @@ public class StandardWebSocketClient implements WebSocketClient {
headers . put ( headerName , value ) ;
}
}
if ( logger . isTrace Enabled ( ) ) {
logger . trace ( "Handshake request headers: " + headers ) ;
if ( logger . isDebug Enabled ( ) ) {
logger . debug ( "Handshake request headers: " + headers ) ;
}
}
@Override
public void afterResponse ( HandshakeResponse handshakeResponse ) {
if ( logger . isTrace Enabled ( ) ) {
logger . trace ( "Handshake response headers: " + handshakeResponse . getHeaders ( ) ) ;
if ( logger . isDebug Enabled ( ) ) {
logger . debug ( "Handshake response headers: " + handshakeResponse . getHeaders ( ) ) ;
}
}
}