From 779779de7befc9ddda57d0edef55d3164b1cde5f Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 11 Apr 2017 17:41:41 -0400 Subject: [PATCH] Refactor use of TaskScheduler in WebSocket Java config Issue: SPR-15233 --- .../AbstractWebSocketHandlerRegistration.java | 31 ++++++- .../ServletWebSocketHandlerRegistration.java | 15 +++- .../ServletWebSocketHandlerRegistry.java | 51 ++++++++++-- .../annotation/SockJsServiceRegistration.java | 27 ++++++- ...MvcStompWebSocketEndpointRegistration.java | 3 +- .../WebSocketConfigurationSupport.java | 80 +++++++++++++++++-- .../WebSocketHandlerRegistrationTests.java | 27 ++++--- 7 files changed, 202 insertions(+), 32 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java index 8b9fb9512c..83c5652a53 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java @@ -43,8 +43,6 @@ import org.springframework.web.socket.sockjs.transport.handler.WebSocketTranspor */ public abstract class AbstractWebSocketHandlerRegistration implements WebSocketHandlerRegistration { - private final TaskScheduler sockJsTaskScheduler; - private final MultiValueMap handlerMap = new LinkedMultiValueMap<>(); private HandshakeHandler handshakeHandler; @@ -55,9 +53,21 @@ public abstract class AbstractWebSocketHandlerRegistration implements WebSock private SockJsServiceRegistration sockJsServiceRegistration; + private TaskScheduler scheduler; + + + public AbstractWebSocketHandlerRegistration() { + } + /** + * Deprecated constructor with a TaskScheduler. + * + * @deprecated as of 5.0 a TaskScheduler is not provided upfront, not until + * it is obvious that it is needed, see {@link #getSockJsServiceRegistration()}. + */ + @Deprecated public AbstractWebSocketHandlerRegistration(TaskScheduler defaultTaskScheduler) { - this.sockJsTaskScheduler = defaultTaskScheduler; + this.scheduler = defaultTaskScheduler; } @@ -98,7 +108,10 @@ public abstract class AbstractWebSocketHandlerRegistration implements WebSock @Override public SockJsServiceRegistration withSockJS() { - this.sockJsServiceRegistration = new SockJsServiceRegistration(this.sockJsTaskScheduler); + this.sockJsServiceRegistration = new SockJsServiceRegistration(); + if (this.scheduler != null) { + this.sockJsServiceRegistration.setTaskScheduler(this.scheduler); + } HandshakeInterceptor[] interceptors = getInterceptors(); if (interceptors.length > 0) { this.sockJsServiceRegistration.setInterceptors(interceptors); @@ -121,6 +134,16 @@ public abstract class AbstractWebSocketHandlerRegistration implements WebSock return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]); } + /** + * Expose the {@code SockJsServiceRegistration} -- if SockJS is enabled or + * {@code null} otherwise -- so that it can be configured with a TaskScheduler + * if the application did not provide one. This should be done prior to + * calling {@link #getMappings()}. + */ + protected SockJsServiceRegistration getSockJsServiceRegistration() { + return this.sockJsServiceRegistration; + } + protected final M getMappings() { M mappings = createMappings(); if (this.sockJsServiceRegistration != null) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistration.java index 609e07ab1d..b21934b43b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistration.java @@ -41,8 +41,19 @@ public class ServletWebSocketHandlerRegistration extends AbstractWebSocketHandlerRegistration> { - public ServletWebSocketHandlerRegistration(TaskScheduler sockJsTaskScheduler) { - super(sockJsTaskScheduler); + public ServletWebSocketHandlerRegistration() { + } + + /** + * Deprecated constructor with a TaskScheduler for SockJS use. + * + * @deprecated as of 5.0 a TaskScheduler is not provided upfront, not until + * it is obvious that it is needed, see {@link #getSockJsServiceRegistration()}. + */ + @Deprecated + @SuppressWarnings("deprecated") + public ServletWebSocketHandlerRegistration(TaskScheduler scheduler) { + super(scheduler); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistry.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistry.java index 8890fb8072..d25dce2092 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistry.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistry.java @@ -25,7 +25,6 @@ import org.springframework.scheduling.TaskScheduler; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.util.MultiValueMap; import org.springframework.web.HttpRequestHandler; -import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.handler.AbstractHandlerMapping; import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; import org.springframework.web.socket.WebSocketHandler; @@ -43,21 +42,33 @@ public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry private final List registrations = new ArrayList<>(4); - private TaskScheduler sockJsTaskScheduler; + private TaskScheduler scheduler; private int order = 1; private UrlPathHelper urlPathHelper; - public ServletWebSocketHandlerRegistry(ThreadPoolTaskScheduler sockJsTaskScheduler) { - this.sockJsTaskScheduler = sockJsTaskScheduler; + public ServletWebSocketHandlerRegistry() { + this.scheduler = null; } + /** + * Deprecated constructor with a TaskScheduler for SockJS use. + * + * @deprecated as of 5.0 a TaskScheduler is not provided upfront, not until + * it is obvious that it is needed, see {@link #requiresTaskScheduler()} and + * {@link #setTaskScheduler}. + */ + @Deprecated + public ServletWebSocketHandlerRegistry(ThreadPoolTaskScheduler scheduler) { + this.scheduler = scheduler; + } + + @Override public WebSocketHandlerRegistration addHandler(WebSocketHandler handler, String... paths) { - ServletWebSocketHandlerRegistration registration = - new ServletWebSocketHandlerRegistration(this.sockJsTaskScheduler); + ServletWebSocketHandlerRegistration registration = new ServletWebSocketHandlerRegistration(); registration.addHandler(handler, paths); this.registrations.add(registration); return registration; @@ -88,12 +99,31 @@ public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry return this.urlPathHelper; } + /** - * Return a {@link HandlerMapping} with mapped {@link HttpRequestHandler}s. + * Whether there are any endpoint SockJS registrations without a TaskScheduler. + * This method should be invoked just before {@link #getHandlerMapping()} to + * allow for registrations to be made first. */ + protected boolean requiresTaskScheduler() { + return this.registrations.stream() + .anyMatch(r -> r.getSockJsServiceRegistration() != null && + r.getSockJsServiceRegistration().getTaskScheduler() == null); + } + + /** + * Configure a TaskScheduler for SockJS endpoints. This should be configured + * before calling {@link #getHandlerMapping()} after checking if + * {@link #requiresTaskScheduler()} returns {@code true}. + */ + protected void setTaskScheduler(TaskScheduler scheduler) { + this.scheduler = scheduler; + } + public AbstractHandlerMapping getHandlerMapping() { Map urlMap = new LinkedHashMap<>(); for (ServletWebSocketHandlerRegistration registration : this.registrations) { + updateTaskScheduler(registration); MultiValueMap mappings = registration.getMappings(); for (HttpRequestHandler httpHandler : mappings.keySet()) { for (String pattern : mappings.get(httpHandler)) { @@ -110,4 +140,11 @@ public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry return hm; } + private void updateTaskScheduler(ServletWebSocketHandlerRegistration registration) { + SockJsServiceRegistration sockJsRegistration = registration.getSockJsServiceRegistration(); + if (sockJsRegistration != null && sockJsRegistration.getTaskScheduler() == null) { + sockJsRegistration.setTaskScheduler(this.scheduler); + } + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java index 2e832fd782..2e3ba2ed4a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java @@ -70,13 +70,29 @@ public class SockJsServiceRegistration { private SockJsMessageCodec messageCodec; + public SockJsServiceRegistration() { + } + + /** + * Deprecated constructor with a TaskScheduler. + * + * @deprecated as of 5.0 a TaskScheduler is not provided upfront, not until + * it is obvious that it is needed; call {@link #getTaskScheduler()} to check + * and then {@link #setTaskScheduler(TaskScheduler)} to set it before a call + * to {@link #createSockJsService()} + */ + @Deprecated public SockJsServiceRegistration(TaskScheduler defaultTaskScheduler) { this.scheduler = defaultTaskScheduler; } - public SockJsServiceRegistration setTaskScheduler(TaskScheduler taskScheduler) { - this.scheduler = taskScheduler; + /** + * A scheduler instance to use for scheduling SockJS heart-beats. + */ + public SockJsServiceRegistration setTaskScheduler(TaskScheduler scheduler) { + Assert.notNull(scheduler, "TaskScheduler is required."); + this.scheduler = scheduler; return this; } @@ -277,6 +293,13 @@ public class SockJsServiceRegistration { return service; } + /** + * Return the TaskScheduler, if configured. + */ + protected TaskScheduler getTaskScheduler() { + return this.scheduler; + } + private TransportHandlingSockJsService createSockJsService() { Assert.state(this.transportHandlers.isEmpty() || this.transportHandlerOverrides.isEmpty(), diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java index b1fc5535ab..fc4d701668 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java @@ -97,7 +97,8 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE @Override public SockJsServiceRegistration withSockJS() { - this.registration = new SockJsServiceRegistration(this.sockJsTaskScheduler); + this.registration = new SockJsServiceRegistration(); + this.registration.setTaskScheduler(this.sockJsTaskScheduler); HandshakeInterceptor[] interceptors = getInterceptors(); if (interceptors.length > 0) { this.registration.setInterceptors(interceptors); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java index 1b8de76bf3..6c16166820 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java @@ -16,7 +16,12 @@ package org.springframework.web.socket.config.annotation; +import java.util.Date; +import java.util.concurrent.ScheduledFuture; + import org.springframework.context.annotation.Bean; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.scheduling.Trigger; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.web.servlet.HandlerMapping; @@ -28,13 +33,28 @@ import org.springframework.web.servlet.HandlerMapping; */ public class WebSocketConfigurationSupport { + private ServletWebSocketHandlerRegistry handlerRegistry; + + private TaskScheduler scheduler; + + @Bean public HandlerMapping webSocketHandlerMapping() { - ServletWebSocketHandlerRegistry registry = new ServletWebSocketHandlerRegistry(defaultSockJsTaskScheduler()); - registerWebSocketHandlers(registry); + ServletWebSocketHandlerRegistry registry = initHandlerRegistry(); + if (registry.requiresTaskScheduler()) { + registry.setTaskScheduler(initTaskScheduler()); + } return registry.getHandlerMapping(); } + private ServletWebSocketHandlerRegistry initHandlerRegistry() { + if (this.handlerRegistry == null) { + this.handlerRegistry = new ServletWebSocketHandlerRegistry(); + registerWebSocketHandlers(this.handlerRegistry); + } + return this.handlerRegistry; + } + protected void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { } @@ -55,12 +75,58 @@ public class WebSocketConfigurationSupport { * */ @Bean - public ThreadPoolTaskScheduler defaultSockJsTaskScheduler() { - ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); - scheduler.setThreadNamePrefix("SockJS-"); - scheduler.setPoolSize(Runtime.getRuntime().availableProcessors()); - scheduler.setRemoveOnCancelPolicy(true); + public TaskScheduler defaultSockJsTaskScheduler() { + return initTaskScheduler(); + } + + private TaskScheduler initTaskScheduler() { + if (this.scheduler == null) { + ServletWebSocketHandlerRegistry registry = initHandlerRegistry(); + if (registry.requiresTaskScheduler()) { + ThreadPoolTaskScheduler threadPoolScheduler = new ThreadPoolTaskScheduler(); + threadPoolScheduler.setThreadNamePrefix("SockJS-"); + threadPoolScheduler.setPoolSize(Runtime.getRuntime().availableProcessors()); + threadPoolScheduler.setRemoveOnCancelPolicy(true); + this.scheduler = threadPoolScheduler; + } + else { + this.scheduler = new NoOpScheduler(); + } + } return scheduler; } + + private static class NoOpScheduler implements TaskScheduler { + + @Override + public ScheduledFuture schedule(Runnable task, Trigger trigger) { + throw new IllegalStateException("Unexpected use of scheduler."); + } + + @Override + public ScheduledFuture schedule(Runnable task, Date startTime) { + throw new IllegalStateException("Unexpected use of scheduler."); + } + + @Override + public ScheduledFuture scheduleAtFixedRate(Runnable task, Date startTime, long period) { + throw new IllegalStateException("Unexpected use of scheduler."); + } + + @Override + public ScheduledFuture scheduleAtFixedRate(Runnable task, long period) { + throw new IllegalStateException("Unexpected use of scheduler."); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay(Runnable task, Date startTime, long delay) { + throw new IllegalStateException("Unexpected use of scheduler."); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay(Runnable task, long delay) { + throw new IllegalStateException("Unexpected use of scheduler."); + } + } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java index 3a18fb7d13..eb754e4229 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java @@ -36,7 +36,10 @@ import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; /** * Test fixture for @@ -54,7 +57,7 @@ public class WebSocketHandlerRegistrationTests { @Before public void setup() { this.taskScheduler = Mockito.mock(TaskScheduler.class); - this.registration = new TestWebSocketHandlerRegistration(taskScheduler); + this.registration = new TestWebSocketHandlerRegistration(); } @Test @@ -68,12 +71,14 @@ public class WebSocketHandlerRegistrationTests { Mapping m1 = mappings.get(0); assertEquals(handler, m1.webSocketHandler); assertEquals("/foo", m1.path); + assertNotNull(m1.interceptors); assertEquals(1, m1.interceptors.length); assertEquals(OriginHandshakeInterceptor.class, m1.interceptors[0].getClass()); Mapping m2 = mappings.get(1); assertEquals(handler, m2.webSocketHandler); assertEquals("/bar", m2.path); + assertNotNull(m2.interceptors); assertEquals(1, m2.interceptors.length); assertEquals(OriginHandshakeInterceptor.class, m2.interceptors[0].getClass()); } @@ -91,6 +96,7 @@ public class WebSocketHandlerRegistrationTests { Mapping mapping = mappings.get(0); assertEquals(handler, mapping.webSocketHandler); assertEquals("/foo", mapping.path); + assertNotNull(mapping.interceptors); assertEquals(2, mapping.interceptors.length); assertEquals(interceptor, mapping.interceptors[0]); assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass()); @@ -109,6 +115,7 @@ public class WebSocketHandlerRegistrationTests { Mapping mapping = mappings.get(0); assertEquals(handler, mapping.webSocketHandler); assertEquals("/foo", mapping.path); + assertNotNull(mapping.interceptors); assertEquals(2, mapping.interceptors.length); assertEquals(interceptor, mapping.interceptors[0]); assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass()); @@ -127,6 +134,7 @@ public class WebSocketHandlerRegistrationTests { Mapping mapping = mappings.get(0); assertEquals(handler, mapping.webSocketHandler); assertEquals("/foo", mapping.path); + assertNotNull(mapping.interceptors); assertEquals(2, mapping.interceptors.length); assertEquals(interceptor, mapping.interceptors[0]); assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass()); @@ -137,8 +145,12 @@ public class WebSocketHandlerRegistrationTests { WebSocketHandler handler = new TextWebSocketHandler(); HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); - this.registration.addHandler(handler, "/foo").addInterceptors(interceptor) - .setAllowedOrigins("http://mydomain1.com").withSockJS(); + this.registration.addHandler(handler, "/foo") + .addInterceptors(interceptor) + .setAllowedOrigins("http://mydomain1.com") + .withSockJS(); + + this.registration.getSockJsServiceRegistration().setTaskScheduler(this.taskScheduler); List mappings = this.registration.getMappings(); assertEquals(1, mappings.size()); @@ -175,6 +187,7 @@ public class WebSocketHandlerRegistrationTests { HandshakeHandler handshakeHandler = new DefaultHandshakeHandler(); this.registration.addHandler(handler, "/foo").setHandshakeHandler(handshakeHandler).withSockJS(); + this.registration.getSockJsServiceRegistration().setTaskScheduler(this.taskScheduler); List mappings = this.registration.getMappings(); assertEquals(1, mappings.size()); @@ -190,11 +203,7 @@ public class WebSocketHandlerRegistrationTests { } - private static class TestWebSocketHandlerRegistration extends AbstractWebSocketHandlerRegistration> { - - public TestWebSocketHandlerRegistration(TaskScheduler sockJsTaskScheduler) { - super(sockJsTaskScheduler); - } + private static class TestWebSocketHandlerRegistration extends AbstractWebSocketHandlerRegistration> { @Override protected List createMappings() {