@ -16,6 +16,7 @@
@@ -16,6 +16,7 @@
package org.springframework.web.socket.messaging ;
import java.io.IOException ;
import java.util.ArrayList ;
import java.util.Arrays ;
import java.util.HashSet ;
@ -24,6 +25,7 @@ import java.util.Map;
@@ -24,6 +25,7 @@ import java.util.Map;
import java.util.Set ;
import java.util.TreeMap ;
import java.util.concurrent.ConcurrentHashMap ;
import java.util.concurrent.locks.ReentrantLock ;
import org.apache.commons.logging.Log ;
import org.apache.commons.logging.LogFactory ;
@ -64,8 +66,18 @@ import org.springframework.web.socket.handler.SessionLimitExceededException;
@@ -64,8 +66,18 @@ import org.springframework.web.socket.handler.SessionLimitExceededException;
public class SubProtocolWebSocketHandler implements WebSocketHandler ,
SubProtocolCapable , MessageHandler , SmartLifecycle {
/ * *
* Sessions connected to this handler use a sub - protocol . Hence we expect to
* receive some client messages . If we don ' t receive any within a minute , the
* connection isn ' t doing well ( proxy issue , slow network ? ) and can be closed .
* @see # checkSessions ( )
* /
private final int TIME_TO_FIRST_MESSAGE = 60 * 1000 ;
private final Log logger = LogFactory . getLog ( SubProtocolWebSocketHandler . class ) ;
private final MessageChannel clientInboundChannel ;
private final SubscribableChannel clientOutboundChannel ;
@ -75,12 +87,16 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
@@ -75,12 +87,16 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
private SubProtocolHandler defaultProtocolHandler ;
private final Map < String , WebSocketSession > sessions = new ConcurrentHashMap < String , WebSocketSession > ( ) ;
private final Map < String , WebSocketSessionHolder > sessions = new ConcurrentHashMap < String , WebSocketSessionHolder > ( ) ;
private int sendTimeLimit = 10 * 1000 ;
private int sendBufferSizeLimit = 512 * 1024 ;
private volatile long lastSessionCheckTime = System . currentTimeMillis ( ) ;
private final ReentrantLock sessionCheckLock = new ReentrantLock ( ) ;
private final Object lifecycleMonitor = new Object ( ) ;
private volatile boolean running = false ;
@ -214,12 +230,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
@@ -214,12 +230,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
this . clientOutboundChannel . unsubscribe ( this ) ;
// Notify sessions to stop flushing messages
for ( WebSocketSession session : this . sessions . values ( ) ) {
for ( WebSocketSessionHolder holder : this . sessions . values ( ) ) {
try {
s ession. close ( CloseStatus . GOING_AWAY ) ;
holder . getS ession( ) . close ( CloseStatus . GOING_AWAY ) ;
}
catch ( Throwable t ) {
logger . error ( "Failed to close session id '" + session . getId ( ) + "': " + t . getMessage ( ) ) ;
logger . error ( "Failed to close '" + holder . getSession ( ) + "': " + t . getMessage ( ) ) ;
}
}
}
@ -235,15 +251,11 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
@@ -235,15 +251,11 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
@Override
public void afterConnectionEstablished ( WebSocketSession session ) throws Exception {
session = new ConcurrentWebSocketSessionDecorator ( session , getSendTimeLimit ( ) , getSendBufferSizeLimit ( ) ) ;
this . sessions . put ( session . getId ( ) , session ) ;
this . sessions . put ( session . getId ( ) , new WebSocketSessionHolder ( session ) ) ;
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "Started WebSocket session=" + session . getId ( ) +
", number of sessions=" + this . sessions . size ( ) ) ;
logger . debug ( "Started session " + session . getId ( ) + ", number of sessions=" + this . sessions . size ( ) ) ;
}
findProtocolHandler ( session ) . afterSessionStarted ( session , this . clientInboundChannel ) ;
}
@ -283,41 +295,49 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
@@ -283,41 +295,49 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
@Override
public void handleMessage ( WebSocketSession session , WebSocketMessage < ? > message ) throws Exception {
findProtocolHandler ( session ) . handleMessageFromClient ( session , message , this . clientInboundChannel ) ;
SubProtocolHandler protocolHandler = findProtocolHandler ( session ) ;
protocolHandler . handleMessageFromClient ( session , message , this . clientInboundChannel ) ;
WebSocketSessionHolder holder = this . sessions . get ( session . getId ( ) ) ;
if ( holder ! = null ) {
holder . setHasHandledMessages ( ) ;
}
else {
// Should never happen
throw new IllegalStateException ( "Session not found: " + session ) ;
}
checkSessions ( ) ;
}
@Override
public void handleMessage ( Message < ? > message ) throws MessagingException {
String sessionId = resolveSessionId ( message ) ;
if ( sessionId = = null ) {
logger . error ( "sessionId not found in message " + message ) ;
return ;
}
WebSocketSession session = this . sessions . get ( sessionId ) ;
if ( session = = null ) {
WebSocketSessionHolder holder = this . sessions . get ( sessionId ) ;
if ( holder = = null ) {
logger . error ( "Session not found for session with id '" + sessionId + "', ignoring message " + message ) ;
return ;
}
WebSocketSession session = holder . getSession ( ) ;
try {
findProtocolHandler ( session ) . handleMessageToClient ( session , message ) ;
}
catch ( SessionLimitExceededException ex ) {
try {
logger . error ( "Terminating session id '" + sessionId + "'" , ex ) ;
logger . error ( "Terminating '" + session + "'" , ex ) ;
// Session may be unresponsive so clear first
clearSession ( session , ex . getStatus ( ) ) ;
session . close ( ex . getStatus ( ) ) ;
}
catch ( Exception secondException ) {
logger . error ( "Exception terminating session id '" + sessionId + "'" , secondException ) ;
logger . error ( "Exception terminating '" + sessionId + "'" , secondException ) ;
}
}
catch ( Exception e ) {
logger . error ( "Failed to send message to client " + message , e ) ;
logger . error ( "Failed to send message to client " + message + " in " + session , e ) ;
}
}
@ -337,6 +357,43 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
@@ -337,6 +357,43 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
return null ;
}
/ * *
* Periodically check sessions to ensure they have received at least one
* message or otherwise close them .
* /
private void checkSessions ( ) throws IOException {
long currentTime = System . currentTimeMillis ( ) ;
if ( ! isRunning ( ) & & currentTime - this . lastSessionCheckTime < TIME_TO_FIRST_MESSAGE ) {
return ;
}
try {
if ( this . sessionCheckLock . tryLock ( ) ) {
for ( WebSocketSessionHolder holder : this . sessions . values ( ) ) {
if ( holder . hasHandledMessages ( ) ) {
continue ;
}
long timeSinceCreated = currentTime - holder . getCreateTime ( ) ;
if ( holder . hasHandledMessages ( ) | | timeSinceCreated < TIME_TO_FIRST_MESSAGE ) {
continue ;
}
WebSocketSession session = holder . getSession ( ) ;
if ( logger . isErrorEnabled ( ) ) {
logger . error ( "No messages received after " + timeSinceCreated + " ms. Closing " + holder ) ;
}
try {
session . close ( CloseStatus . PROTOCOL_ERROR ) ;
}
catch ( Throwable t ) {
logger . error ( "Failed to close " + session , t ) ;
}
}
}
}
finally {
this . sessionCheckLock . unlock ( ) ;
}
}
@Override
public void handleTransportError ( WebSocketSession session , Throwable exception ) throws Exception {
}
@ -356,4 +413,45 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
@@ -356,4 +413,45 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
return false ;
}
private static class WebSocketSessionHolder {
private final WebSocketSession session ;
private final long createTime = System . currentTimeMillis ( ) ;
private volatile boolean handledMessages ;
private WebSocketSessionHolder ( WebSocketSession session ) {
this . session = session ;
}
public WebSocketSession getSession ( ) {
return this . session ;
}
public long getCreateTime ( ) {
return this . createTime ;
}
public void setHasHandledMessages ( ) {
this . handledMessages = true ;
}
public boolean hasHandledMessages ( ) {
return this . handledMessages ;
}
@Override
public String toString ( ) {
if ( this . session instanceof ConcurrentWebSocketSessionDecorator ) {
return ( ( ConcurrentWebSocketSessionDecorator ) this . session ) . getLastSession ( ) . toString ( ) ;
}
else {
return this . session . toString ( ) ;
}
}
}
}