diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java index f52bae7a5f..15dd486a75 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java @@ -169,8 +169,31 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat @Override public void close(CloseStatus status) throws IOException { - this.closeInProgress = true; - super.close(status); + this.closeLock.lock(); + try { + if (this.closeInProgress) { + return; + } + if (!CloseStatus.SESSION_NOT_RELIABLE.equals(status)) { + try { + checkSessionLimits(); + } + catch (SessionLimitExceededException ex) { + // Ignore + } + if (this.limitExceeded) { + if (logger.isDebugEnabled()) { + logger.debug("Changing close status " + status + " to SESSION_NOT_RELIABLE."); + } + status = CloseStatus.SESSION_NOT_RELIABLE; + } + } + this.closeInProgress = true; + super.close(status); + } + finally { + this.closeLock.unlock(); + } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java index 6523624e1f..5a0010eebf 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java @@ -27,6 +27,7 @@ import org.junit.Test; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; +import org.springframework.web.socket.WebSocketSession; import static org.junit.Assert.*; @@ -187,6 +188,56 @@ public class ConcurrentWebSocketSessionDecoratorTests { } } + @Test + public void closeStatusNormal() throws Exception { + + BlockingSession delegate = new BlockingSession(); + delegate.setOpen(true); + WebSocketSession decorator = new ConcurrentWebSocketSessionDecorator(delegate, 10 * 1000, 1024); + + decorator.close(CloseStatus.PROTOCOL_ERROR); + assertEquals(CloseStatus.PROTOCOL_ERROR, delegate.getCloseStatus()); + + decorator.close(CloseStatus.SERVER_ERROR); + assertEquals("Should have been ignored", CloseStatus.PROTOCOL_ERROR, delegate.getCloseStatus()); + } + + @Test + public void closeStatusChangesToSessionNotReliable() throws Exception { + + BlockingSession blockingSession = new BlockingSession(); + blockingSession.setId("123"); + blockingSession.setOpen(true); + CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch(); + + int sendTimeLimit = 100; + int bufferSizeLimit = 1024; + + final ConcurrentWebSocketSessionDecorator concurrentSession = + new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit); + + Executors.newSingleThreadExecutor().submit((Runnable) () -> { + TextMessage message = new TextMessage("slow message"); + try { + concurrentSession.sendMessage(message); + } + catch (IOException e) { + e.printStackTrace(); + } + }); + + assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS)); + + // ensure some send time elapses + Thread.sleep(sendTimeLimit + 100); + + concurrentSession.close(CloseStatus.PROTOCOL_ERROR); + + assertEquals("CloseStatus should have changed to SESSION_NOT_RELIABLE", + CloseStatus.SESSION_NOT_RELIABLE, blockingSession.getCloseStatus()); + } + + private static class BlockingSession extends TestWebSocketSession {