Browse Source

Refactor use of TaskScheduler in WebSocket Java config

Issue: SPR-15233
pull/1386/head
Rossen Stoyanchev 8 years ago
parent
commit
779779de7b
  1. 31
      spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java
  2. 15
      spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistration.java
  3. 51
      spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistry.java
  4. 27
      spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java
  5. 3
      spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java
  6. 80
      spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java
  7. 27
      spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java

31
spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java

@ -43,8 +43,6 @@ import org.springframework.web.socket.sockjs.transport.handler.WebSocketTranspor @@ -43,8 +43,6 @@ import org.springframework.web.socket.sockjs.transport.handler.WebSocketTranspor
*/
public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSocketHandlerRegistration {
private final TaskScheduler sockJsTaskScheduler;
private final MultiValueMap<WebSocketHandler, String> handlerMap = new LinkedMultiValueMap<>();
private HandshakeHandler handshakeHandler;
@ -55,9 +53,21 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock @@ -55,9 +53,21 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock
private SockJsServiceRegistration sockJsServiceRegistration;
private TaskScheduler scheduler;
public AbstractWebSocketHandlerRegistration() {
}
/**
* Deprecated constructor with a TaskScheduler.
*
* @deprecated as of 5.0 a TaskScheduler is not provided upfront, not until
* it is obvious that it is needed, see {@link #getSockJsServiceRegistration()}.
*/
@Deprecated
public AbstractWebSocketHandlerRegistration(TaskScheduler defaultTaskScheduler) {
this.sockJsTaskScheduler = defaultTaskScheduler;
this.scheduler = defaultTaskScheduler;
}
@ -98,7 +108,10 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock @@ -98,7 +108,10 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock
@Override
public SockJsServiceRegistration withSockJS() {
this.sockJsServiceRegistration = new SockJsServiceRegistration(this.sockJsTaskScheduler);
this.sockJsServiceRegistration = new SockJsServiceRegistration();
if (this.scheduler != null) {
this.sockJsServiceRegistration.setTaskScheduler(this.scheduler);
}
HandshakeInterceptor[] interceptors = getInterceptors();
if (interceptors.length > 0) {
this.sockJsServiceRegistration.setInterceptors(interceptors);
@ -121,6 +134,16 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock @@ -121,6 +134,16 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock
return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]);
}
/**
* Expose the {@code SockJsServiceRegistration} -- if SockJS is enabled or
* {@code null} otherwise -- so that it can be configured with a TaskScheduler
* if the application did not provide one. This should be done prior to
* calling {@link #getMappings()}.
*/
protected SockJsServiceRegistration getSockJsServiceRegistration() {
return this.sockJsServiceRegistration;
}
protected final M getMappings() {
M mappings = createMappings();
if (this.sockJsServiceRegistration != null) {

15
spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistration.java

@ -41,8 +41,19 @@ public class ServletWebSocketHandlerRegistration @@ -41,8 +41,19 @@ public class ServletWebSocketHandlerRegistration
extends AbstractWebSocketHandlerRegistration<MultiValueMap<HttpRequestHandler, String>> {
public ServletWebSocketHandlerRegistration(TaskScheduler sockJsTaskScheduler) {
super(sockJsTaskScheduler);
public ServletWebSocketHandlerRegistration() {
}
/**
* Deprecated constructor with a TaskScheduler for SockJS use.
*
* @deprecated as of 5.0 a TaskScheduler is not provided upfront, not until
* it is obvious that it is needed, see {@link #getSockJsServiceRegistration()}.
*/
@Deprecated
@SuppressWarnings("deprecated")
public ServletWebSocketHandlerRegistration(TaskScheduler scheduler) {
super(scheduler);
}

51
spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/ServletWebSocketHandlerRegistry.java

@ -25,7 +25,6 @@ import org.springframework.scheduling.TaskScheduler; @@ -25,7 +25,6 @@ import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
@ -43,21 +42,33 @@ public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry @@ -43,21 +42,33 @@ public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry
private final List<ServletWebSocketHandlerRegistration> registrations = new ArrayList<>(4);
private TaskScheduler sockJsTaskScheduler;
private TaskScheduler scheduler;
private int order = 1;
private UrlPathHelper urlPathHelper;
public ServletWebSocketHandlerRegistry(ThreadPoolTaskScheduler sockJsTaskScheduler) {
this.sockJsTaskScheduler = sockJsTaskScheduler;
public ServletWebSocketHandlerRegistry() {
this.scheduler = null;
}
/**
* Deprecated constructor with a TaskScheduler for SockJS use.
*
* @deprecated as of 5.0 a TaskScheduler is not provided upfront, not until
* it is obvious that it is needed, see {@link #requiresTaskScheduler()} and
* {@link #setTaskScheduler}.
*/
@Deprecated
public ServletWebSocketHandlerRegistry(ThreadPoolTaskScheduler scheduler) {
this.scheduler = scheduler;
}
@Override
public WebSocketHandlerRegistration addHandler(WebSocketHandler handler, String... paths) {
ServletWebSocketHandlerRegistration registration =
new ServletWebSocketHandlerRegistration(this.sockJsTaskScheduler);
ServletWebSocketHandlerRegistration registration = new ServletWebSocketHandlerRegistration();
registration.addHandler(handler, paths);
this.registrations.add(registration);
return registration;
@ -88,12 +99,31 @@ public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry @@ -88,12 +99,31 @@ public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry
return this.urlPathHelper;
}
/**
* Return a {@link HandlerMapping} with mapped {@link HttpRequestHandler}s.
* Whether there are any endpoint SockJS registrations without a TaskScheduler.
* This method should be invoked just before {@link #getHandlerMapping()} to
* allow for registrations to be made first.
*/
protected boolean requiresTaskScheduler() {
return this.registrations.stream()
.anyMatch(r -> r.getSockJsServiceRegistration() != null &&
r.getSockJsServiceRegistration().getTaskScheduler() == null);
}
/**
* Configure a TaskScheduler for SockJS endpoints. This should be configured
* before calling {@link #getHandlerMapping()} after checking if
* {@link #requiresTaskScheduler()} returns {@code true}.
*/
protected void setTaskScheduler(TaskScheduler scheduler) {
this.scheduler = scheduler;
}
public AbstractHandlerMapping getHandlerMapping() {
Map<String, Object> urlMap = new LinkedHashMap<>();
for (ServletWebSocketHandlerRegistration registration : this.registrations) {
updateTaskScheduler(registration);
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
for (HttpRequestHandler httpHandler : mappings.keySet()) {
for (String pattern : mappings.get(httpHandler)) {
@ -110,4 +140,11 @@ public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry @@ -110,4 +140,11 @@ public class ServletWebSocketHandlerRegistry implements WebSocketHandlerRegistry
return hm;
}
private void updateTaskScheduler(ServletWebSocketHandlerRegistration registration) {
SockJsServiceRegistration sockJsRegistration = registration.getSockJsServiceRegistration();
if (sockJsRegistration != null && sockJsRegistration.getTaskScheduler() == null) {
sockJsRegistration.setTaskScheduler(this.scheduler);
}
}
}

27
spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java

@ -70,13 +70,29 @@ public class SockJsServiceRegistration { @@ -70,13 +70,29 @@ public class SockJsServiceRegistration {
private SockJsMessageCodec messageCodec;
public SockJsServiceRegistration() {
}
/**
* Deprecated constructor with a TaskScheduler.
*
* @deprecated as of 5.0 a TaskScheduler is not provided upfront, not until
* it is obvious that it is needed; call {@link #getTaskScheduler()} to check
* and then {@link #setTaskScheduler(TaskScheduler)} to set it before a call
* to {@link #createSockJsService()}
*/
@Deprecated
public SockJsServiceRegistration(TaskScheduler defaultTaskScheduler) {
this.scheduler = defaultTaskScheduler;
}
public SockJsServiceRegistration setTaskScheduler(TaskScheduler taskScheduler) {
this.scheduler = taskScheduler;
/**
* A scheduler instance to use for scheduling SockJS heart-beats.
*/
public SockJsServiceRegistration setTaskScheduler(TaskScheduler scheduler) {
Assert.notNull(scheduler, "TaskScheduler is required.");
this.scheduler = scheduler;
return this;
}
@ -277,6 +293,13 @@ public class SockJsServiceRegistration { @@ -277,6 +293,13 @@ public class SockJsServiceRegistration {
return service;
}
/**
* Return the TaskScheduler, if configured.
*/
protected TaskScheduler getTaskScheduler() {
return this.scheduler;
}
private TransportHandlingSockJsService createSockJsService() {
Assert.state(this.transportHandlers.isEmpty() || this.transportHandlerOverrides.isEmpty(),

3
spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java

@ -97,7 +97,8 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE @@ -97,7 +97,8 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
@Override
public SockJsServiceRegistration withSockJS() {
this.registration = new SockJsServiceRegistration(this.sockJsTaskScheduler);
this.registration = new SockJsServiceRegistration();
this.registration.setTaskScheduler(this.sockJsTaskScheduler);
HandshakeInterceptor[] interceptors = getInterceptors();
if (interceptors.length > 0) {
this.registration.setInterceptors(interceptors);

80
spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java

@ -16,7 +16,12 @@ @@ -16,7 +16,12 @@
package org.springframework.web.socket.config.annotation;
import java.util.Date;
import java.util.concurrent.ScheduledFuture;
import org.springframework.context.annotation.Bean;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.Trigger;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.web.servlet.HandlerMapping;
@ -28,13 +33,28 @@ import org.springframework.web.servlet.HandlerMapping; @@ -28,13 +33,28 @@ import org.springframework.web.servlet.HandlerMapping;
*/
public class WebSocketConfigurationSupport {
private ServletWebSocketHandlerRegistry handlerRegistry;
private TaskScheduler scheduler;
@Bean
public HandlerMapping webSocketHandlerMapping() {
ServletWebSocketHandlerRegistry registry = new ServletWebSocketHandlerRegistry(defaultSockJsTaskScheduler());
registerWebSocketHandlers(registry);
ServletWebSocketHandlerRegistry registry = initHandlerRegistry();
if (registry.requiresTaskScheduler()) {
registry.setTaskScheduler(initTaskScheduler());
}
return registry.getHandlerMapping();
}
private ServletWebSocketHandlerRegistry initHandlerRegistry() {
if (this.handlerRegistry == null) {
this.handlerRegistry = new ServletWebSocketHandlerRegistry();
registerWebSocketHandlers(this.handlerRegistry);
}
return this.handlerRegistry;
}
protected void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
}
@ -55,12 +75,58 @@ public class WebSocketConfigurationSupport { @@ -55,12 +75,58 @@ public class WebSocketConfigurationSupport {
* </pre>
*/
@Bean
public ThreadPoolTaskScheduler defaultSockJsTaskScheduler() {
ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler();
scheduler.setThreadNamePrefix("SockJS-");
scheduler.setPoolSize(Runtime.getRuntime().availableProcessors());
scheduler.setRemoveOnCancelPolicy(true);
public TaskScheduler defaultSockJsTaskScheduler() {
return initTaskScheduler();
}
private TaskScheduler initTaskScheduler() {
if (this.scheduler == null) {
ServletWebSocketHandlerRegistry registry = initHandlerRegistry();
if (registry.requiresTaskScheduler()) {
ThreadPoolTaskScheduler threadPoolScheduler = new ThreadPoolTaskScheduler();
threadPoolScheduler.setThreadNamePrefix("SockJS-");
threadPoolScheduler.setPoolSize(Runtime.getRuntime().availableProcessors());
threadPoolScheduler.setRemoveOnCancelPolicy(true);
this.scheduler = threadPoolScheduler;
}
else {
this.scheduler = new NoOpScheduler();
}
}
return scheduler;
}
private static class NoOpScheduler implements TaskScheduler {
@Override
public ScheduledFuture<?> schedule(Runnable task, Trigger trigger) {
throw new IllegalStateException("Unexpected use of scheduler.");
}
@Override
public ScheduledFuture<?> schedule(Runnable task, Date startTime) {
throw new IllegalStateException("Unexpected use of scheduler.");
}
@Override
public ScheduledFuture<?> scheduleAtFixedRate(Runnable task, Date startTime, long period) {
throw new IllegalStateException("Unexpected use of scheduler.");
}
@Override
public ScheduledFuture<?> scheduleAtFixedRate(Runnable task, long period) {
throw new IllegalStateException("Unexpected use of scheduler.");
}
@Override
public ScheduledFuture<?> scheduleWithFixedDelay(Runnable task, Date startTime, long delay) {
throw new IllegalStateException("Unexpected use of scheduler.");
}
@Override
public ScheduledFuture<?> scheduleWithFixedDelay(Runnable task, long delay) {
throw new IllegalStateException("Unexpected use of scheduler.");
}
}
}

27
spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java

@ -36,7 +36,10 @@ import org.springframework.web.socket.sockjs.transport.TransportType; @@ -36,7 +36,10 @@ import org.springframework.web.socket.sockjs.transport.TransportType;
import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
/**
* Test fixture for
@ -54,7 +57,7 @@ public class WebSocketHandlerRegistrationTests { @@ -54,7 +57,7 @@ public class WebSocketHandlerRegistrationTests {
@Before
public void setup() {
this.taskScheduler = Mockito.mock(TaskScheduler.class);
this.registration = new TestWebSocketHandlerRegistration(taskScheduler);
this.registration = new TestWebSocketHandlerRegistration();
}
@Test
@ -68,12 +71,14 @@ public class WebSocketHandlerRegistrationTests { @@ -68,12 +71,14 @@ public class WebSocketHandlerRegistrationTests {
Mapping m1 = mappings.get(0);
assertEquals(handler, m1.webSocketHandler);
assertEquals("/foo", m1.path);
assertNotNull(m1.interceptors);
assertEquals(1, m1.interceptors.length);
assertEquals(OriginHandshakeInterceptor.class, m1.interceptors[0].getClass());
Mapping m2 = mappings.get(1);
assertEquals(handler, m2.webSocketHandler);
assertEquals("/bar", m2.path);
assertNotNull(m2.interceptors);
assertEquals(1, m2.interceptors.length);
assertEquals(OriginHandshakeInterceptor.class, m2.interceptors[0].getClass());
}
@ -91,6 +96,7 @@ public class WebSocketHandlerRegistrationTests { @@ -91,6 +96,7 @@ public class WebSocketHandlerRegistrationTests {
Mapping mapping = mappings.get(0);
assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo", mapping.path);
assertNotNull(mapping.interceptors);
assertEquals(2, mapping.interceptors.length);
assertEquals(interceptor, mapping.interceptors[0]);
assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
@ -109,6 +115,7 @@ public class WebSocketHandlerRegistrationTests { @@ -109,6 +115,7 @@ public class WebSocketHandlerRegistrationTests {
Mapping mapping = mappings.get(0);
assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo", mapping.path);
assertNotNull(mapping.interceptors);
assertEquals(2, mapping.interceptors.length);
assertEquals(interceptor, mapping.interceptors[0]);
assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
@ -127,6 +134,7 @@ public class WebSocketHandlerRegistrationTests { @@ -127,6 +134,7 @@ public class WebSocketHandlerRegistrationTests {
Mapping mapping = mappings.get(0);
assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo", mapping.path);
assertNotNull(mapping.interceptors);
assertEquals(2, mapping.interceptors.length);
assertEquals(interceptor, mapping.interceptors[0]);
assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
@ -137,8 +145,12 @@ public class WebSocketHandlerRegistrationTests { @@ -137,8 +145,12 @@ public class WebSocketHandlerRegistrationTests {
WebSocketHandler handler = new TextWebSocketHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
this.registration.addHandler(handler, "/foo").addInterceptors(interceptor)
.setAllowedOrigins("http://mydomain1.com").withSockJS();
this.registration.addHandler(handler, "/foo")
.addInterceptors(interceptor)
.setAllowedOrigins("http://mydomain1.com")
.withSockJS();
this.registration.getSockJsServiceRegistration().setTaskScheduler(this.taskScheduler);
List<Mapping> mappings = this.registration.getMappings();
assertEquals(1, mappings.size());
@ -175,6 +187,7 @@ public class WebSocketHandlerRegistrationTests { @@ -175,6 +187,7 @@ public class WebSocketHandlerRegistrationTests {
HandshakeHandler handshakeHandler = new DefaultHandshakeHandler();
this.registration.addHandler(handler, "/foo").setHandshakeHandler(handshakeHandler).withSockJS();
this.registration.getSockJsServiceRegistration().setTaskScheduler(this.taskScheduler);
List<Mapping> mappings = this.registration.getMappings();
assertEquals(1, mappings.size());
@ -190,11 +203,7 @@ public class WebSocketHandlerRegistrationTests { @@ -190,11 +203,7 @@ public class WebSocketHandlerRegistrationTests {
}
private static class TestWebSocketHandlerRegistration extends AbstractWebSocketHandlerRegistration<List<Mapping>> {
public TestWebSocketHandlerRegistration(TaskScheduler sockJsTaskScheduler) {
super(sockJsTaskScheduler);
}
private static class TestWebSocketHandlerRegistration extends AbstractWebSocketHandlerRegistration<List<Mapping>> {
@Override
protected List<Mapping> createMappings() {

Loading…
Cancel
Save