Browse Source

Add allowedOriginPatterns to SockJS config

See gh-26108
pull/26122/head
Benjamin Faal 4 years ago committed by Rossen Stoyanchev
parent
commit
ae75db2657
  1. 15
      spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java
  2. 7
      spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java
  3. 22
      spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java
  4. 41
      spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java
  5. 23
      spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java
  6. 26
      spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java

15
spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java

@ -73,6 +73,8 @@ public class SockJsServiceRegistration { @@ -73,6 +73,8 @@ public class SockJsServiceRegistration {
private final List<String> allowedOrigins = new ArrayList<>();
private final List<String> allowedOriginPatterns = new ArrayList<>();
@Nullable
private Boolean suppressCors;
@ -232,6 +234,18 @@ public class SockJsServiceRegistration { @@ -232,6 +234,18 @@ public class SockJsServiceRegistration {
return this;
}
/**
* Configure allowed {@code Origin} pattern header values.
* @since 5.3.2
*/
protected SockJsServiceRegistration setAllowedOriginPatterns(String... allowedOriginPatterns) {
this.allowedOriginPatterns.clear();
if (!ObjectUtils.isEmpty(allowedOriginPatterns)) {
this.allowedOriginPatterns.addAll(Arrays.asList(allowedOriginPatterns));
}
return this;
}
/**
* This option can be used to disable automatic addition of CORS headers for
* SockJS requests.
@ -284,6 +298,7 @@ public class SockJsServiceRegistration { @@ -284,6 +298,7 @@ public class SockJsServiceRegistration {
service.setSuppressCors(this.suppressCors);
}
service.setAllowedOrigins(this.allowedOrigins);
service.setAllowedOriginPatterns(this.allowedOriginPatterns);
if (this.messageCodec != null) {
service.setMessageCodec(this.messageCodec);

7
spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java

@ -61,4 +61,11 @@ public interface StompWebSocketEndpointRegistration { @@ -61,4 +61,11 @@ public interface StompWebSocketEndpointRegistration {
*/
StompWebSocketEndpointRegistration setAllowedOrigins(String... origins);
/**
* Configure allowed {@code Origin} header values.
*
* @see org.springframework.web.cors.CorsConfiguration#setAllowedOriginPatterns(java.util.List)
*/
StompWebSocketEndpointRegistration setAllowedOriginPatterns(String... originPatterns);
}

22
spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java

@ -58,6 +58,8 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE @@ -58,6 +58,8 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
private final List<String> allowedOrigins = new ArrayList<>();
private final List<String> allowedOriginPatterns = new ArrayList<>();
@Nullable
private SockJsServiceRegistration registration;
@ -97,6 +99,15 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE @@ -97,6 +99,15 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
return this;
}
@Override
public StompWebSocketEndpointRegistration setAllowedOriginPatterns(String... allowedOriginPatterns) {
this.allowedOriginPatterns.clear();
if (!ObjectUtils.isEmpty(allowedOriginPatterns)) {
this.allowedOriginPatterns.addAll(Arrays.asList(allowedOriginPatterns));
}
return this;
}
@Override
public SockJsServiceRegistration withSockJS() {
this.registration = new SockJsServiceRegistration();
@ -112,13 +123,22 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE @@ -112,13 +123,22 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
if (!this.allowedOrigins.isEmpty()) {
this.registration.setAllowedOrigins(StringUtils.toStringArray(this.allowedOrigins));
}
if (!this.allowedOriginPatterns.isEmpty()) {
this.registration.setAllowedOriginPatterns(StringUtils.toStringArray(this.allowedOriginPatterns));
}
return this.registration;
}
protected HandshakeInterceptor[] getInterceptors() {
List<HandshakeInterceptor> interceptors = new ArrayList<>(this.interceptors.size() + 1);
interceptors.addAll(this.interceptors);
interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins));
OriginHandshakeInterceptor originHandshakeInterceptor = new OriginHandshakeInterceptor(this.allowedOrigins);
interceptors.add(originHandshakeInterceptor);
if (!ObjectUtils.isEmpty(this.allowedOriginPatterns)) {
originHandshakeInterceptor.setAllowedOriginPatterns(this.allowedOriginPatterns);
}
return interceptors.toArray(new HandshakeInterceptor[0]);
}

41
spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java

@ -16,11 +16,12 @@ @@ -16,11 +16,12 @@
package org.springframework.web.socket.server.support;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@ -30,6 +31,7 @@ import org.springframework.http.server.ServerHttpRequest; @@ -30,6 +31,7 @@ import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.WebUtils;
@ -45,7 +47,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { @@ -45,7 +47,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
protected final Log logger = LogFactory.getLog(getClass());
private final Set<String> allowedOrigins = new LinkedHashSet<>();
private final CorsConfiguration corsConfiguration = new CorsConfiguration();
/**
@ -74,8 +76,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { @@ -74,8 +76,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
*/
public void setAllowedOrigins(Collection<String> allowedOrigins) {
Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null");
this.allowedOrigins.clear();
this.allowedOrigins.addAll(allowedOrigins);
this.corsConfiguration.setAllowedOrigins(new ArrayList<>(allowedOrigins));
}
/**
@ -84,7 +85,33 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { @@ -84,7 +85,33 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
* @see #setAllowedOrigins
*/
public Collection<String> getAllowedOrigins() {
return Collections.unmodifiableSet(this.allowedOrigins);
if (this.corsConfiguration.getAllowedOrigins() == null) {
return Collections.emptyList();
}
return Collections.unmodifiableSet(new HashSet<>(this.corsConfiguration.getAllowedOrigins()));
}
/**
* Configure allowed {@code Origin} pattern header values.
*
* @see CorsConfiguration#setAllowedOriginPatterns(List)
*/
public void setAllowedOriginPatterns(Collection<String> allowedOriginPatterns) {
Assert.notNull(allowedOriginPatterns, "Allowed origin patterns Collection must not be null");
this.corsConfiguration.setAllowedOriginPatterns(new ArrayList<>(allowedOriginPatterns));
}
/**
* Return the allowed {@code Origin} pattern header values.
*
* @since 5.3.2
* @see CorsConfiguration#getAllowedOriginPatterns()
*/
public Collection<String> getAllowedOriginPatterns() {
if (this.corsConfiguration.getAllowedOriginPatterns() == null) {
return Collections.emptyList();
}
return Collections.unmodifiableSet(new HashSet<>(this.corsConfiguration.getAllowedOriginPatterns()));
}
@ -92,7 +119,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { @@ -92,7 +119,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
if (!WebUtils.isSameOrigin(request) && !WebUtils.isValidOrigin(request, this.allowedOrigins)) {
if (!WebUtils.isSameOrigin(request) && this.corsConfiguration.checkOrigin(request.getHeaders().getOrigin()) == null) {
response.setStatusCode(HttpStatus.FORBIDDEN);
if (logger.isDebugEnabled()) {
logger.debug("Handshake request rejected, Origin header value " +

23
spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java

@ -99,6 +99,8 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig @@ -99,6 +99,8 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
protected final Set<String> allowedOrigins = new LinkedHashSet<>();
protected final Set<String> allowedOriginPatterns = new LinkedHashSet<>();
private final SockJsRequestHandler infoHandler = new InfoHandler();
private final SockJsRequestHandler iframeHandler = new IframeHandler();
@ -319,6 +321,17 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig @@ -319,6 +321,17 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
this.allowedOrigins.addAll(allowedOrigins);
}
/**
* Configure allowed {@code Origin} header values.
*
* @see org.springframework.web.cors.CorsConfiguration#setAllowedOriginPatterns(java.util.List)
*/
public void setAllowedOriginPatterns(Collection<String> allowedOriginPatterns) {
Assert.notNull(allowedOriginPatterns, "Allowed origin patterns Collection must not be null");
this.allowedOriginPatterns.clear();
this.allowedOriginPatterns.addAll(allowedOriginPatterns);
}
/**
* Return configure allowed {@code Origin} header values.
* @since 4.1.2
@ -328,6 +341,15 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig @@ -328,6 +341,15 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
return Collections.unmodifiableSet(this.allowedOrigins);
}
/**
* Return configure allowed {@code Origin} pattern header values.
* @since 5.3.2
* @see #setAllowedOriginPatterns
*/
public Collection<String> getAllowedOriginPatterns() {
return Collections.unmodifiableSet(this.allowedOriginPatterns);
}
/**
* This method determines the SockJS path and handles SockJS static URLs.
@ -498,6 +520,7 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig @@ -498,6 +520,7 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
if (!this.suppressCors && (request.getHeader(HttpHeaders.ORIGIN) != null)) {
CorsConfiguration config = new CorsConfiguration();
config.setAllowedOrigins(new ArrayList<>(this.allowedOrigins));
config.setAllowedOriginPatterns(new ArrayList<>(this.allowedOriginPatterns));
config.addAllowedMethod("*");
config.setAllowCredentials(true);
config.setMaxAge(ONE_YEAR);

26
spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java

@ -135,6 +135,32 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { @@ -135,6 +135,32 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
assertThat(sockJsService.shouldSuppressCors()).isFalse();
}
@Test
public void allowedOriginPatterns() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
String origin = "https://*.mydomain.com";
registration.setAllowedOriginPatterns(origin).withSockJS();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1);
SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertThat(requestHandler.getSockJsService()).isNotNull();
DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
assertThat(sockJsService.getAllowedOriginPatterns().contains(origin)).isTrue();
registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.withSockJS().setAllowedOriginPatterns(origin);
mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1);
requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertThat(requestHandler.getSockJsService()).isNotNull();
sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
assertThat(sockJsService.getAllowedOriginPatterns().contains(origin)).isTrue();
}
@Test // SPR-12283
public void disableCorsWithSockJsService() {
WebMvcStompWebSocketEndpointRegistration registration =

Loading…
Cancel
Save