diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java index 572f5ffc29..a5ae409c47 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java @@ -16,6 +16,7 @@ package org.springframework.web.socket.messaging; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; @@ -24,6 +25,7 @@ import java.util.Map; import java.util.Set; import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -64,8 +66,18 @@ import org.springframework.web.socket.handler.SessionLimitExceededException; public class SubProtocolWebSocketHandler implements WebSocketHandler, SubProtocolCapable, MessageHandler, SmartLifecycle { + /** + * Sessions connected to this handler use a sub-protocol. Hence we expect to + * receive some client messages. If we don't receive any within a minute, the + * connection isn't doing well (proxy issue, slow network?) and can be closed. + * @see #checkSessions() + */ + private final int TIME_TO_FIRST_MESSAGE = 60 * 1000; + + private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class); + private final MessageChannel clientInboundChannel; private final SubscribableChannel clientOutboundChannel; @@ -75,12 +87,16 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, private SubProtocolHandler defaultProtocolHandler; - private final Map sessions = new ConcurrentHashMap(); + private final Map sessions = new ConcurrentHashMap(); private int sendTimeLimit = 10 * 1000; private int sendBufferSizeLimit = 512 * 1024; + private volatile long lastSessionCheckTime = System.currentTimeMillis(); + + private final ReentrantLock sessionCheckLock = new ReentrantLock(); + private final Object lifecycleMonitor = new Object(); private volatile boolean running = false; @@ -214,12 +230,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, this.clientOutboundChannel.unsubscribe(this); // Notify sessions to stop flushing messages - for (WebSocketSession session : this.sessions.values()) { + for (WebSocketSessionHolder holder : this.sessions.values()) { try { - session.close(CloseStatus.GOING_AWAY); + holder.getSession().close(CloseStatus.GOING_AWAY); } catch (Throwable t) { - logger.error("Failed to close session id '" + session.getId() + "': " + t.getMessage()); + logger.error("Failed to close '" + holder.getSession() + "': " + t.getMessage()); } } } @@ -235,15 +251,11 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { - session = new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit()); - - this.sessions.put(session.getId(), session); + this.sessions.put(session.getId(), new WebSocketSessionHolder(session)); if (logger.isDebugEnabled()) { - logger.debug("Started WebSocket session=" + session.getId() + - ", number of sessions=" + this.sessions.size()); + logger.debug("Started session " + session.getId() + ", number of sessions=" + this.sessions.size()); } - findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel); } @@ -283,41 +295,49 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @Override public void handleMessage(WebSocketSession session, WebSocketMessage message) throws Exception { - findProtocolHandler(session).handleMessageFromClient(session, message, this.clientInboundChannel); + SubProtocolHandler protocolHandler = findProtocolHandler(session); + protocolHandler.handleMessageFromClient(session, message, this.clientInboundChannel); + WebSocketSessionHolder holder = this.sessions.get(session.getId()); + if (holder != null) { + holder.setHasHandledMessages(); + } + else { + // Should never happen + throw new IllegalStateException("Session not found: " + session); + } + checkSessions(); } @Override public void handleMessage(Message message) throws MessagingException { - String sessionId = resolveSessionId(message); if (sessionId == null) { logger.error("sessionId not found in message " + message); return; } - - WebSocketSession session = this.sessions.get(sessionId); - if (session == null) { + WebSocketSessionHolder holder = this.sessions.get(sessionId); + if (holder == null) { logger.error("Session not found for session with id '" + sessionId + "', ignoring message " + message); return; } - + WebSocketSession session = holder.getSession(); try { findProtocolHandler(session).handleMessageToClient(session, message); } catch (SessionLimitExceededException ex) { try { - logger.error("Terminating session id '" + sessionId + "'", ex); + logger.error("Terminating '" + session + "'", ex); // Session may be unresponsive so clear first clearSession(session, ex.getStatus()); session.close(ex.getStatus()); } catch (Exception secondException) { - logger.error("Exception terminating session id '" + sessionId + "'", secondException); + logger.error("Exception terminating '" + sessionId + "'", secondException); } } catch (Exception e) { - logger.error("Failed to send message to client " + message, e); + logger.error("Failed to send message to client " + message + " in " + session, e); } } @@ -337,6 +357,43 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, return null; } + /** + * Periodically check sessions to ensure they have received at least one + * message or otherwise close them. + */ + private void checkSessions() throws IOException { + long currentTime = System.currentTimeMillis(); + if (!isRunning() && currentTime - this.lastSessionCheckTime < TIME_TO_FIRST_MESSAGE) { + return; + } + try { + if (this.sessionCheckLock.tryLock()) { + for (WebSocketSessionHolder holder : this.sessions.values()) { + if (holder.hasHandledMessages()) { + continue; + } + long timeSinceCreated = currentTime - holder.getCreateTime(); + if (holder.hasHandledMessages() || timeSinceCreated < TIME_TO_FIRST_MESSAGE) { + continue; + } + WebSocketSession session = holder.getSession(); + if (logger.isErrorEnabled()) { + logger.error("No messages received after " + timeSinceCreated + " ms. Closing " + holder); + } + try { + session.close(CloseStatus.PROTOCOL_ERROR); + } + catch (Throwable t) { + logger.error("Failed to close " + session, t); + } + } + } + } + finally { + this.sessionCheckLock.unlock(); + } + } + @Override public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { } @@ -356,4 +413,45 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, return false; } + + private static class WebSocketSessionHolder { + + private final WebSocketSession session; + + private final long createTime = System.currentTimeMillis(); + + private volatile boolean handledMessages; + + + private WebSocketSessionHolder(WebSocketSession session) { + this.session = session; + } + + public WebSocketSession getSession() { + return this.session; + } + + public long getCreateTime() { + return this.createTime; + } + + public void setHasHandledMessages() { + this.handledMessages = true; + } + + public boolean hasHandledMessages() { + return this.handledMessages; + } + + @Override + public String toString() { + if (this.session instanceof ConcurrentWebSocketSessionDecorator) { + return ((ConcurrentWebSocketSessionDecorator) this.session).getLastSession().toString(); + } + else { + return this.session.toString(); + } + } + } + } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java index 31cf82ac76..3ef5d87f0b 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java @@ -17,16 +17,24 @@ package org.springframework.web.socket.messaging; import java.util.Arrays; +import java.util.Map; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.springframework.beans.DirectFieldAccessor; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.SubscribableChannel; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; import org.springframework.web.socket.handler.TestWebSocketSession; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.*; /** @@ -56,11 +64,9 @@ public class SubProtocolWebSocketHandlerTests { @Before public void setup() { MockitoAnnotations.initMocks(this); - this.webSocketHandler = new SubProtocolWebSocketHandler(this.inClientChannel, this.outClientChannel); when(stompHandler.getSupportedProtocols()).thenReturn(Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp")); when(mqttHandler.getSupportedProtocols()).thenReturn(Arrays.asList("MQTT")); - this.session = new TestWebSocketSession(); this.session.setId("1"); } @@ -140,4 +146,32 @@ public class SubProtocolWebSocketHandlerTests { this.webSocketHandler.afterConnectionEstablished(session); } + @Test + public void checkSession() throws Exception { + TestWebSocketSession session1 = new TestWebSocketSession("id1"); + TestWebSocketSession session2 = new TestWebSocketSession("id2"); + session1.setAcceptedProtocol("v12.stomp"); + session2.setAcceptedProtocol("v12.stomp"); + + this.webSocketHandler.setProtocolHandlers(Arrays.asList(this.stompHandler)); + this.webSocketHandler.afterConnectionEstablished(session1); + this.webSocketHandler.afterConnectionEstablished(session2); + session1.setOpen(true); + session2.setOpen(true); + + long sixtyOneSecondsAgo = System.currentTimeMillis() - 61 * 1000; + new DirectFieldAccessor(this.webSocketHandler).setPropertyValue("lastSessionCheckTime", sixtyOneSecondsAgo); + Map sessions = (Map) new DirectFieldAccessor(this.webSocketHandler).getPropertyValue("sessions"); + new DirectFieldAccessor(sessions.get("id1")).setPropertyValue("createTime", sixtyOneSecondsAgo); + new DirectFieldAccessor(sessions.get("id2")).setPropertyValue("createTime", sixtyOneSecondsAgo); + + this.webSocketHandler.handleMessage(session1, new TextMessage("foo")); + + assertTrue(session1.isOpen()); + assertFalse(session2.isOpen()); + assertNull(session1.getCloseStatus()); + assertEquals(CloseStatus.PROTOCOL_ERROR, session2.getCloseStatus()); + } + + }