Browse Source

Add check for unused WebSocket sessions

Sessions connected to a STOMP endpoint are expected to receive some
client messages. Having received none after successfully connecting
could be an indication of proxy or network issue. This change adds
periodic checks to see if we have not received any messages on a
session which is an indication the session isn't going anywhere
most likely due to a proxy issue (or unreliable network) and close
those sessions.

Issue: SPR-11884
pull/543/merge
Rossen Stoyanchev 11 years ago
parent
commit
a3fa9c9797
  1. 136
      spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java
  2. 38
      spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java

136
spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java

@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
package org.springframework.web.socket.messaging;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
@ -24,6 +25,7 @@ import java.util.Map; @@ -24,6 +25,7 @@ import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -64,8 +66,18 @@ import org.springframework.web.socket.handler.SessionLimitExceededException; @@ -64,8 +66,18 @@ import org.springframework.web.socket.handler.SessionLimitExceededException;
public class SubProtocolWebSocketHandler implements WebSocketHandler,
SubProtocolCapable, MessageHandler, SmartLifecycle {
/**
* Sessions connected to this handler use a sub-protocol. Hence we expect to
* receive some client messages. If we don't receive any within a minute, the
* connection isn't doing well (proxy issue, slow network?) and can be closed.
* @see #checkSessions()
*/
private final int TIME_TO_FIRST_MESSAGE = 60 * 1000;
private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class);
private final MessageChannel clientInboundChannel;
private final SubscribableChannel clientOutboundChannel;
@ -75,12 +87,16 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @@ -75,12 +87,16 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
private SubProtocolHandler defaultProtocolHandler;
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<String, WebSocketSession>();
private final Map<String, WebSocketSessionHolder> sessions = new ConcurrentHashMap<String, WebSocketSessionHolder>();
private int sendTimeLimit = 10 * 1000;
private int sendBufferSizeLimit = 512 * 1024;
private volatile long lastSessionCheckTime = System.currentTimeMillis();
private final ReentrantLock sessionCheckLock = new ReentrantLock();
private final Object lifecycleMonitor = new Object();
private volatile boolean running = false;
@ -214,12 +230,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @@ -214,12 +230,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
this.clientOutboundChannel.unsubscribe(this);
// Notify sessions to stop flushing messages
for (WebSocketSession session : this.sessions.values()) {
for (WebSocketSessionHolder holder : this.sessions.values()) {
try {
session.close(CloseStatus.GOING_AWAY);
holder.getSession().close(CloseStatus.GOING_AWAY);
}
catch (Throwable t) {
logger.error("Failed to close session id '" + session.getId() + "': " + t.getMessage());
logger.error("Failed to close '" + holder.getSession() + "': " + t.getMessage());
}
}
}
@ -235,15 +251,11 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @@ -235,15 +251,11 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
session = new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit());
this.sessions.put(session.getId(), session);
this.sessions.put(session.getId(), new WebSocketSessionHolder(session));
if (logger.isDebugEnabled()) {
logger.debug("Started WebSocket session=" + session.getId() +
", number of sessions=" + this.sessions.size());
logger.debug("Started session " + session.getId() + ", number of sessions=" + this.sessions.size());
}
findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel);
}
@ -283,41 +295,49 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @@ -283,41 +295,49 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
@Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
findProtocolHandler(session).handleMessageFromClient(session, message, this.clientInboundChannel);
SubProtocolHandler protocolHandler = findProtocolHandler(session);
protocolHandler.handleMessageFromClient(session, message, this.clientInboundChannel);
WebSocketSessionHolder holder = this.sessions.get(session.getId());
if (holder != null) {
holder.setHasHandledMessages();
}
else {
// Should never happen
throw new IllegalStateException("Session not found: " + session);
}
checkSessions();
}
@Override
public void handleMessage(Message<?> message) throws MessagingException {
String sessionId = resolveSessionId(message);
if (sessionId == null) {
logger.error("sessionId not found in message " + message);
return;
}
WebSocketSession session = this.sessions.get(sessionId);
if (session == null) {
WebSocketSessionHolder holder = this.sessions.get(sessionId);
if (holder == null) {
logger.error("Session not found for session with id '" + sessionId + "', ignoring message " + message);
return;
}
WebSocketSession session = holder.getSession();
try {
findProtocolHandler(session).handleMessageToClient(session, message);
}
catch (SessionLimitExceededException ex) {
try {
logger.error("Terminating session id '" + sessionId + "'", ex);
logger.error("Terminating '" + session + "'", ex);
// Session may be unresponsive so clear first
clearSession(session, ex.getStatus());
session.close(ex.getStatus());
}
catch (Exception secondException) {
logger.error("Exception terminating session id '" + sessionId + "'", secondException);
logger.error("Exception terminating '" + sessionId + "'", secondException);
}
}
catch (Exception e) {
logger.error("Failed to send message to client " + message, e);
logger.error("Failed to send message to client " + message + " in " + session, e);
}
}
@ -337,6 +357,43 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @@ -337,6 +357,43 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
return null;
}
/**
* Periodically check sessions to ensure they have received at least one
* message or otherwise close them.
*/
private void checkSessions() throws IOException {
long currentTime = System.currentTimeMillis();
if (!isRunning() && currentTime - this.lastSessionCheckTime < TIME_TO_FIRST_MESSAGE) {
return;
}
try {
if (this.sessionCheckLock.tryLock()) {
for (WebSocketSessionHolder holder : this.sessions.values()) {
if (holder.hasHandledMessages()) {
continue;
}
long timeSinceCreated = currentTime - holder.getCreateTime();
if (holder.hasHandledMessages() || timeSinceCreated < TIME_TO_FIRST_MESSAGE) {
continue;
}
WebSocketSession session = holder.getSession();
if (logger.isErrorEnabled()) {
logger.error("No messages received after " + timeSinceCreated + " ms. Closing " + holder);
}
try {
session.close(CloseStatus.PROTOCOL_ERROR);
}
catch (Throwable t) {
logger.error("Failed to close " + session, t);
}
}
}
}
finally {
this.sessionCheckLock.unlock();
}
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
}
@ -356,4 +413,45 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @@ -356,4 +413,45 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
return false;
}
private static class WebSocketSessionHolder {
private final WebSocketSession session;
private final long createTime = System.currentTimeMillis();
private volatile boolean handledMessages;
private WebSocketSessionHolder(WebSocketSession session) {
this.session = session;
}
public WebSocketSession getSession() {
return this.session;
}
public long getCreateTime() {
return this.createTime;
}
public void setHasHandledMessages() {
this.handledMessages = true;
}
public boolean hasHandledMessages() {
return this.handledMessages;
}
@Override
public String toString() {
if (this.session instanceof ConcurrentWebSocketSessionDecorator) {
return ((ConcurrentWebSocketSessionDecorator) this.session).getLastSession().toString();
}
else {
return this.session.toString();
}
}
}
}

38
spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java

@ -17,16 +17,24 @@ @@ -17,16 +17,24 @@
package org.springframework.web.socket.messaging;
import java.util.Arrays;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.TestWebSocketSession;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.*;
/**
@ -56,11 +64,9 @@ public class SubProtocolWebSocketHandlerTests { @@ -56,11 +64,9 @@ public class SubProtocolWebSocketHandlerTests {
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
this.webSocketHandler = new SubProtocolWebSocketHandler(this.inClientChannel, this.outClientChannel);
when(stompHandler.getSupportedProtocols()).thenReturn(Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp"));
when(mqttHandler.getSupportedProtocols()).thenReturn(Arrays.asList("MQTT"));
this.session = new TestWebSocketSession();
this.session.setId("1");
}
@ -140,4 +146,32 @@ public class SubProtocolWebSocketHandlerTests { @@ -140,4 +146,32 @@ public class SubProtocolWebSocketHandlerTests {
this.webSocketHandler.afterConnectionEstablished(session);
}
@Test
public void checkSession() throws Exception {
TestWebSocketSession session1 = new TestWebSocketSession("id1");
TestWebSocketSession session2 = new TestWebSocketSession("id2");
session1.setAcceptedProtocol("v12.stomp");
session2.setAcceptedProtocol("v12.stomp");
this.webSocketHandler.setProtocolHandlers(Arrays.asList(this.stompHandler));
this.webSocketHandler.afterConnectionEstablished(session1);
this.webSocketHandler.afterConnectionEstablished(session2);
session1.setOpen(true);
session2.setOpen(true);
long sixtyOneSecondsAgo = System.currentTimeMillis() - 61 * 1000;
new DirectFieldAccessor(this.webSocketHandler).setPropertyValue("lastSessionCheckTime", sixtyOneSecondsAgo);
Map<String, ?> sessions = (Map<String, ?>) new DirectFieldAccessor(this.webSocketHandler).getPropertyValue("sessions");
new DirectFieldAccessor(sessions.get("id1")).setPropertyValue("createTime", sixtyOneSecondsAgo);
new DirectFieldAccessor(sessions.get("id2")).setPropertyValue("createTime", sixtyOneSecondsAgo);
this.webSocketHandler.handleMessage(session1, new TextMessage("foo"));
assertTrue(session1.isOpen());
assertFalse(session2.isOpen());
assertNull(session1.getCloseStatus());
assertEquals(CloseStatus.PROTOCOL_ERROR, session2.getCloseStatus());
}
}

Loading…
Cancel
Save