From ae75db265704103eab8bad8ca388d2d1cd9ed846 Mon Sep 17 00:00:00 2001 From: Benjamin Faal Date: Wed, 18 Nov 2020 18:41:59 +0100 Subject: [PATCH] Add allowedOriginPatterns to SockJS config See gh-26108 --- .../annotation/SockJsServiceRegistration.java | 15 +++++++ .../StompWebSocketEndpointRegistration.java | 7 ++++ ...MvcStompWebSocketEndpointRegistration.java | 22 +++++++++- .../support/OriginHandshakeInterceptor.java | 41 +++++++++++++++---- .../sockjs/support/AbstractSockJsService.java | 23 +++++++++++ ...ompWebSocketEndpointRegistrationTests.java | 26 ++++++++++++ 6 files changed, 126 insertions(+), 8 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java index 7b184f11bc..4342fd2bfb 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java @@ -73,6 +73,8 @@ public class SockJsServiceRegistration { private final List allowedOrigins = new ArrayList<>(); + private final List allowedOriginPatterns = new ArrayList<>(); + @Nullable private Boolean suppressCors; @@ -232,6 +234,18 @@ public class SockJsServiceRegistration { return this; } + /** + * Configure allowed {@code Origin} pattern header values. + * @since 5.3.2 + */ + protected SockJsServiceRegistration setAllowedOriginPatterns(String... allowedOriginPatterns) { + this.allowedOriginPatterns.clear(); + if (!ObjectUtils.isEmpty(allowedOriginPatterns)) { + this.allowedOriginPatterns.addAll(Arrays.asList(allowedOriginPatterns)); + } + return this; + } + /** * This option can be used to disable automatic addition of CORS headers for * SockJS requests. @@ -284,6 +298,7 @@ public class SockJsServiceRegistration { service.setSuppressCors(this.suppressCors); } service.setAllowedOrigins(this.allowedOrigins); + service.setAllowedOriginPatterns(this.allowedOriginPatterns); if (this.messageCodec != null) { service.setMessageCodec(this.messageCodec); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java index 4739649197..3fcb8caeee 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java @@ -61,4 +61,11 @@ public interface StompWebSocketEndpointRegistration { */ StompWebSocketEndpointRegistration setAllowedOrigins(String... origins); + /** + * Configure allowed {@code Origin} header values. + * + * @see org.springframework.web.cors.CorsConfiguration#setAllowedOriginPatterns(java.util.List) + */ + StompWebSocketEndpointRegistration setAllowedOriginPatterns(String... originPatterns); + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java index d62ba0596e..e9dcb84b7a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java @@ -58,6 +58,8 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE private final List allowedOrigins = new ArrayList<>(); + private final List allowedOriginPatterns = new ArrayList<>(); + @Nullable private SockJsServiceRegistration registration; @@ -97,6 +99,15 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE return this; } + @Override + public StompWebSocketEndpointRegistration setAllowedOriginPatterns(String... allowedOriginPatterns) { + this.allowedOriginPatterns.clear(); + if (!ObjectUtils.isEmpty(allowedOriginPatterns)) { + this.allowedOriginPatterns.addAll(Arrays.asList(allowedOriginPatterns)); + } + return this; + } + @Override public SockJsServiceRegistration withSockJS() { this.registration = new SockJsServiceRegistration(); @@ -112,13 +123,22 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE if (!this.allowedOrigins.isEmpty()) { this.registration.setAllowedOrigins(StringUtils.toStringArray(this.allowedOrigins)); } + if (!this.allowedOriginPatterns.isEmpty()) { + this.registration.setAllowedOriginPatterns(StringUtils.toStringArray(this.allowedOriginPatterns)); + } return this.registration; } protected HandshakeInterceptor[] getInterceptors() { List interceptors = new ArrayList<>(this.interceptors.size() + 1); interceptors.addAll(this.interceptors); - interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins)); + OriginHandshakeInterceptor originHandshakeInterceptor = new OriginHandshakeInterceptor(this.allowedOrigins); + interceptors.add(originHandshakeInterceptor); + + if (!ObjectUtils.isEmpty(this.allowedOriginPatterns)) { + originHandshakeInterceptor.setAllowedOriginPatterns(this.allowedOriginPatterns); + } + return interceptors.toArray(new HandshakeInterceptor[0]); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java index f10ec90f8d..12fbd8b6f3 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java @@ -16,11 +16,12 @@ package org.springframework.web.socket.server.support; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.LinkedHashSet; +import java.util.HashSet; +import java.util.List; import java.util.Map; -import java.util.Set; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -30,6 +31,7 @@ import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.util.WebUtils; @@ -45,7 +47,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { protected final Log logger = LogFactory.getLog(getClass()); - private final Set allowedOrigins = new LinkedHashSet<>(); + private final CorsConfiguration corsConfiguration = new CorsConfiguration(); /** @@ -74,8 +76,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { */ public void setAllowedOrigins(Collection allowedOrigins) { Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null"); - this.allowedOrigins.clear(); - this.allowedOrigins.addAll(allowedOrigins); + this.corsConfiguration.setAllowedOrigins(new ArrayList<>(allowedOrigins)); } /** @@ -84,7 +85,33 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { * @see #setAllowedOrigins */ public Collection getAllowedOrigins() { - return Collections.unmodifiableSet(this.allowedOrigins); + if (this.corsConfiguration.getAllowedOrigins() == null) { + return Collections.emptyList(); + } + return Collections.unmodifiableSet(new HashSet<>(this.corsConfiguration.getAllowedOrigins())); + } + + /** + * Configure allowed {@code Origin} pattern header values. + * + * @see CorsConfiguration#setAllowedOriginPatterns(List) + */ + public void setAllowedOriginPatterns(Collection allowedOriginPatterns) { + Assert.notNull(allowedOriginPatterns, "Allowed origin patterns Collection must not be null"); + this.corsConfiguration.setAllowedOriginPatterns(new ArrayList<>(allowedOriginPatterns)); + } + + /** + * Return the allowed {@code Origin} pattern header values. + * + * @since 5.3.2 + * @see CorsConfiguration#getAllowedOriginPatterns() + */ + public Collection getAllowedOriginPatterns() { + if (this.corsConfiguration.getAllowedOriginPatterns() == null) { + return Collections.emptyList(); + } + return Collections.unmodifiableSet(new HashSet<>(this.corsConfiguration.getAllowedOriginPatterns())); } @@ -92,7 +119,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws Exception { - if (!WebUtils.isSameOrigin(request) && !WebUtils.isValidOrigin(request, this.allowedOrigins)) { + if (!WebUtils.isSameOrigin(request) && this.corsConfiguration.checkOrigin(request.getHeaders().getOrigin()) == null) { response.setStatusCode(HttpStatus.FORBIDDEN); if (logger.isDebugEnabled()) { logger.debug("Handshake request rejected, Origin header value " + diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java index 5dc84bbab0..7d04d82c02 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java @@ -99,6 +99,8 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig protected final Set allowedOrigins = new LinkedHashSet<>(); + protected final Set allowedOriginPatterns = new LinkedHashSet<>(); + private final SockJsRequestHandler infoHandler = new InfoHandler(); private final SockJsRequestHandler iframeHandler = new IframeHandler(); @@ -319,6 +321,17 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig this.allowedOrigins.addAll(allowedOrigins); } + /** + * Configure allowed {@code Origin} header values. + * + * @see org.springframework.web.cors.CorsConfiguration#setAllowedOriginPatterns(java.util.List) + */ + public void setAllowedOriginPatterns(Collection allowedOriginPatterns) { + Assert.notNull(allowedOriginPatterns, "Allowed origin patterns Collection must not be null"); + this.allowedOriginPatterns.clear(); + this.allowedOriginPatterns.addAll(allowedOriginPatterns); + } + /** * Return configure allowed {@code Origin} header values. * @since 4.1.2 @@ -328,6 +341,15 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig return Collections.unmodifiableSet(this.allowedOrigins); } + /** + * Return configure allowed {@code Origin} pattern header values. + * @since 5.3.2 + * @see #setAllowedOriginPatterns + */ + public Collection getAllowedOriginPatterns() { + return Collections.unmodifiableSet(this.allowedOriginPatterns); + } + /** * This method determines the SockJS path and handles SockJS static URLs. @@ -498,6 +520,7 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig if (!this.suppressCors && (request.getHeader(HttpHeaders.ORIGIN) != null)) { CorsConfiguration config = new CorsConfiguration(); config.setAllowedOrigins(new ArrayList<>(this.allowedOrigins)); + config.setAllowedOriginPatterns(new ArrayList<>(this.allowedOriginPatterns)); config.addAllowedMethod("*"); config.setAllowCredentials(true); config.setMaxAge(ONE_YEAR); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java index 9a5fe7945f..eeab9c04e4 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java @@ -135,6 +135,32 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { assertThat(sockJsService.shouldSuppressCors()).isFalse(); } + @Test + public void allowedOriginPatterns() { + WebMvcStompWebSocketEndpointRegistration registration = + new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); + + String origin = "https://*.mydomain.com"; + registration.setAllowedOriginPatterns(origin).withSockJS(); + + MultiValueMap mappings = registration.getMappings(); + assertThat(mappings.size()).isEqualTo(1); + SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); + assertThat(requestHandler.getSockJsService()).isNotNull(); + DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); + assertThat(sockJsService.getAllowedOriginPatterns().contains(origin)).isTrue(); + + registration = + new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); + registration.withSockJS().setAllowedOriginPatterns(origin); + mappings = registration.getMappings(); + assertThat(mappings.size()).isEqualTo(1); + requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); + assertThat(requestHandler.getSockJsService()).isNotNull(); + sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); + assertThat(sockJsService.getAllowedOriginPatterns().contains(origin)).isTrue(); + } + @Test // SPR-12283 public void disableCorsWithSockJsService() { WebMvcStompWebSocketEndpointRegistration registration =