@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
/ *
* Copyright 2002 - 2014 the original author or authors .
* Copyright 2002 - 2015 the original author or authors .
*
* Licensed under the Apache License , Version 2 . 0 ( the "License" ) ;
* you may not use this file except in compliance with the License .
@ -17,8 +17,14 @@
@@ -17,8 +17,14 @@
package org.springframework.web.socket ;
import static org.junit.Assert.* ;
import java.net.URI ;
import java.util.ArrayList ;
import java.util.Arrays ;
import java.util.List ;
import java.util.concurrent.CountDownLatch ;
import java.util.concurrent.TimeUnit ;
import org.junit.Test ;
import org.junit.runner.RunWith ;
@ -32,10 +38,10 @@ import org.springframework.web.socket.client.standard.StandardWebSocketClient;
@@ -32,10 +38,10 @@ import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.config.annotation.EnableWebSocket ;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer ;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry ;
import org.springframework.web.socket.handler.AbstractWebSocketHandler ;
import org.springframework.web.socket.handler.TextWebSocketHandler ;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler ;
import static org.junit.Assert.* ;
/ * *
* Client and server - side WebSocket integration tests .
@ -67,6 +73,24 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTest
@@ -67,6 +73,24 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTest
URI url = new URI ( getWsBaseUrl ( ) + "/ws" ) ;
WebSocketSession session = this . webSocketClient . doHandshake ( new TextWebSocketHandler ( ) , headers , url ) . get ( ) ;
assertEquals ( "foo" , session . getAcceptedProtocol ( ) ) ;
session . close ( ) ;
}
// SPR-12727
@Test
public void unsolicitedPongWithEmptyPayload ( ) throws Exception {
TestWebSocketHandler serverHandler = this . wac . getBean ( TestWebSocketHandler . class ) ;
serverHandler . setWaitMessageCount ( 1 ) ;
String url = getWsBaseUrl ( ) + "/ws" ;
WebSocketSession session = this . webSocketClient . doHandshake ( new AbstractWebSocketHandler ( ) { } , url ) . get ( ) ;
session . sendMessage ( new PongMessage ( ) ) ;
serverHandler . await ( ) ;
assertNull ( serverHandler . getTransportError ( ) ) ;
assertEquals ( 1 , serverHandler . getReceivedMessages ( ) . size ( ) ) ;
assertEquals ( PongMessage . class , serverHandler . getReceivedMessages ( ) . get ( 0 ) . getClass ( ) ) ;
}
@ -84,8 +108,51 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTest
@@ -84,8 +108,51 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTest
}
@Bean
public TextWebSocketHandler handler ( ) {
return new TextWebSocketHandler ( ) ;
public TestWebSocketHandler handler ( ) {
return new TestWebSocketHandler ( ) ;
}
}
private static class TestWebSocketHandler extends AbstractWebSocketHandler {
private List < WebSocketMessage > receivedMessages = new ArrayList < > ( ) ;
private int waitMessageCount ;
private final CountDownLatch latch = new CountDownLatch ( 1 ) ;
private Throwable transportError ;
public void setWaitMessageCount ( int waitMessageCount ) {
this . waitMessageCount = waitMessageCount ;
}
public List < WebSocketMessage > getReceivedMessages ( ) {
return this . receivedMessages ;
}
public Throwable getTransportError ( ) {
return this . transportError ;
}
@Override
public void handleMessage ( WebSocketSession session , WebSocketMessage < ? > message ) throws Exception {
this . receivedMessages . add ( message ) ;
if ( this . receivedMessages . size ( ) > = this . waitMessageCount ) {
this . latch . countDown ( ) ;
}
}
@Override
public void handleTransportError ( WebSocketSession session , Throwable exception ) throws Exception {
this . transportError = exception ;
this . latch . countDown ( ) ;
}
public void await ( ) throws InterruptedException {
this . latch . await ( 5 , TimeUnit . SECONDS ) ;
}
}