Browse Source

CorsConfiguration now supports pattern based origins.

Closes gh-24763
pull/25381/head
Ruslan Akhundov 5 years ago committed by Rossen Stoyanchev
parent
commit
8632118e8d
  1. 20
      spring-web/src/main/java/org/springframework/web/bind/annotation/CrossOrigin.java
  2. 120
      spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java
  3. 84
      spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java
  4. 3
      spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java
  5. 15
      spring-webflux/src/test/java/org/springframework/web/reactive/handler/CorsUrlHandlerMappingTests.java
  6. 33
      spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/CrossOriginAnnotationIntegrationTests.java
  7. 3
      spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerMapping.java
  8. 14
      spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java
  9. 35
      spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java

20
spring-web/src/main/java/org/springframework/web/bind/annotation/CrossOrigin.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -45,6 +45,7 @@ import org.springframework.web.cors.CorsConfiguration; @@ -45,6 +45,7 @@ import org.springframework.web.cors.CorsConfiguration;
* @author Russell Allen
* @author Sebastien Deleuze
* @author Sam Brannen
* @author Ruslan Akhundov
* @since 4.2
*/
@Target({ElementType.TYPE, ElementType.METHOD})
@ -93,6 +94,23 @@ public @interface CrossOrigin { @@ -93,6 +94,23 @@ public @interface CrossOrigin {
@AliasFor("value")
String[] origins() default {};
/**
* The list of allowed origins patterns that be specific origins, e.g.
* {@code ".*\.domain1\.com"}, or {@code ".*"} for matching all origins.
* <p>A matched origin is listed in the {@code Access-Control-Allow-Origin}
* response header of preflight actual CORS requests.
* <p>By default all origins are allowed.
* <p><strong>Note:</strong> CORS checks use values from "Forwarded"
* (<a href="https://tools.ietf.org/html/rfc7239">RFC 7239</a>),
* "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" headers,
* if present, in order to reflect the client-originated address.
* Consider using the {@code ForwardedHeaderFilter} in order to choose from a
* central place whether to extract and use, or to discard such headers.
* See the Spring Framework reference for more on this filter.
* @see #value
*/
String[] originsPatterns() default {};
/**
* The list of request headers that are permitted in actual requests,
* possibly {@code "*"} to allow all headers.

120
spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,6 +23,7 @@ import java.util.Collections; @@ -23,6 +23,7 @@ import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.springframework.http.HttpMethod;
@ -45,6 +46,7 @@ import org.springframework.util.StringUtils; @@ -45,6 +46,7 @@ import org.springframework.util.StringUtils;
* @author Rossen Stoyanchev
* @author Juergen Hoeller
* @author Sam Brannen
* @author Ruslan Akhundov
* @since 4.2
* @see <a href="https://www.w3.org/TR/cors/">CORS spec</a>
*/
@ -52,6 +54,8 @@ public class CorsConfiguration { @@ -52,6 +54,8 @@ public class CorsConfiguration {
/** Wildcard representing <em>all</em> origins, methods, or headers. */
public static final String ALL = "*";
/** Wildcard representing pattern that matches <em>all</em> origins. */
public static final String ALL_PATTERN = ".*";
private static final List<HttpMethod> DEFAULT_METHODS = Collections.unmodifiableList(
Arrays.asList(HttpMethod.GET, HttpMethod.HEAD));
@ -62,10 +66,19 @@ public class CorsConfiguration { @@ -62,10 +66,19 @@ public class CorsConfiguration {
private static final List<String> DEFAULT_PERMIT_ALL = Collections.unmodifiableList(
Collections.singletonList(ALL));
private static final List<String> DEFAULT_PERMIT_ALL_PATTERN_STR = Collections.unmodifiableList(
Collections.singletonList(ALL_PATTERN));
private static final List<Pattern> DEFAULT_PERMIT_ALL_PATTERN = Collections.unmodifiableList(
Collections.singletonList(Pattern.compile(ALL_PATTERN)));
@Nullable
private List<String> allowedOrigins;
@Nullable
private List<Pattern> allowedOriginsPatterns;
@Nullable
private List<String> allowedMethods;
@ -99,6 +112,7 @@ public class CorsConfiguration { @@ -99,6 +112,7 @@ public class CorsConfiguration {
*/
public CorsConfiguration(CorsConfiguration other) {
this.allowedOrigins = other.allowedOrigins;
this.allowedOriginsPatterns = other.allowedOriginsPatterns;
this.allowedMethods = other.allowedMethods;
this.resolvedMethods = other.resolvedMethods;
this.allowedHeaders = other.allowedHeaders;
@ -140,6 +154,54 @@ public class CorsConfiguration { @@ -140,6 +154,54 @@ public class CorsConfiguration {
this.allowedOrigins.add(origin);
}
/**
* Set the origins patterns to allow, e.g. {@code "*.com"}.
* <p>By default this is not set.
*/
public CorsConfiguration setAllowedOriginsPatterns(@Nullable List<String> allowedOriginsPatterns) {
if (allowedOriginsPatterns == null) {
this.allowedOriginsPatterns = null;
}
else {
this.allowedOriginsPatterns = new ArrayList<>(allowedOriginsPatterns.size());
for (String pattern : allowedOriginsPatterns) {
this.allowedOriginsPatterns.add(Pattern.compile(pattern));
}
}
return this;
}
/**
* Return the configured origins patterns to allow, or {@code null} if none.
*
* @see #addAllowedOriginPattern(String)
* @see #setAllowedOriginsPatterns(List)
*/
@Nullable
public List<String> getAllowedOriginsPatterns() {
if (this.allowedOriginsPatterns == null) {
return null;
}
if (this.allowedOriginsPatterns == DEFAULT_PERMIT_ALL_PATTERN) {
return DEFAULT_PERMIT_ALL_PATTERN_STR;
}
return this.allowedOriginsPatterns.stream().map(Pattern::toString).collect(Collectors.toList());
}
/**
* Add an origin pattern to allow.
*/
public void addAllowedOriginPattern(String originPattern) {
if (this.allowedOriginsPatterns == null) {
this.allowedOriginsPatterns = new ArrayList<>(4);
}
else if (this.allowedOriginsPatterns == DEFAULT_PERMIT_ALL_PATTERN) {
setAllowedOriginsPatterns(DEFAULT_PERMIT_ALL_PATTERN_STR);
}
this.allowedOriginsPatterns.add(Pattern.compile(originPattern));
}
/**
* Set the HTTP methods to allow, e.g. {@code "GET"}, {@code "POST"},
* {@code "PUT"}, etc.
@ -351,7 +413,7 @@ public class CorsConfiguration { @@ -351,7 +413,7 @@ public class CorsConfiguration {
* </ul>
*/
public CorsConfiguration applyPermitDefaultValues() {
if (this.allowedOrigins == null) {
if (this.allowedOrigins == null && this.allowedOriginsPatterns == null) {
this.allowedOrigins = DEFAULT_PERMIT_ALL;
}
if (this.allowedMethods == null) {
@ -392,7 +454,14 @@ public class CorsConfiguration { @@ -392,7 +454,14 @@ public class CorsConfiguration {
return this;
}
CorsConfiguration config = new CorsConfiguration(this);
config.setAllowedOrigins(combine(getAllowedOrigins(), other.getAllowedOrigins()));
List<String> combinedOrigins = combine(getAllowedOrigins(), other.getAllowedOrigins());
List<String> combinedOriginsPatterns = combine(getAllowedOriginsPatterns(), other.getAllowedOriginsPatterns());
if (combinedOrigins == DEFAULT_PERMIT_ALL && combinedOriginsPatterns != DEFAULT_PERMIT_ALL_PATTERN_STR
&& !CollectionUtils.isEmpty(combinedOriginsPatterns)) {
combinedOrigins = null;
}
config.setAllowedOrigins(combinedOrigins);
config.setAllowedOriginsPatterns(combinedOriginsPatterns);
config.setAllowedMethods(combine(getAllowedMethods(), other.getAllowedMethods()));
config.setAllowedHeaders(combine(getAllowedHeaders(), other.getAllowedHeaders()));
config.setExposedHeaders(combine(getExposedHeaders(), other.getExposedHeaders()));
@ -414,15 +483,20 @@ public class CorsConfiguration { @@ -414,15 +483,20 @@ public class CorsConfiguration {
if (source == null) {
return other;
}
if (source == DEFAULT_PERMIT_ALL || source == DEFAULT_PERMIT_METHODS) {
if (source == DEFAULT_PERMIT_ALL || source == DEFAULT_PERMIT_METHODS
|| source == DEFAULT_PERMIT_ALL_PATTERN_STR) {
return other;
}
if (other == DEFAULT_PERMIT_ALL || other == DEFAULT_PERMIT_METHODS) {
if (other == DEFAULT_PERMIT_ALL || other == DEFAULT_PERMIT_METHODS
|| other == DEFAULT_PERMIT_ALL_PATTERN_STR) {
return source;
}
if (source.contains(ALL) || other.contains(ALL)) {
return new ArrayList<>(Collections.singletonList(ALL));
}
if ( source.contains(ALL_PATTERN) || other.contains(ALL_PATTERN)) {
return new ArrayList<>(Collections.singletonList(ALL_PATTERN));
}
Set<String> combined = new LinkedHashSet<>(source);
combined.addAll(other);
return new ArrayList<>(combined);
@ -439,21 +513,35 @@ public class CorsConfiguration { @@ -439,21 +513,35 @@ public class CorsConfiguration {
if (!StringUtils.hasText(requestOrigin)) {
return null;
}
if (ObjectUtils.isEmpty(this.allowedOrigins)) {
return null;
}
if (this.allowedOrigins.contains(ALL)) {
if (this.allowCredentials != Boolean.TRUE) {
return ALL;
if (!ObjectUtils.isEmpty(this.allowedOrigins)) {
if (this.allowedOrigins.contains(ALL)) {
if (this.allowCredentials != Boolean.TRUE) {
return ALL;
}
else {
return requestOrigin;
}
}
else {
return requestOrigin;
for (String allowedOrigin : this.allowedOrigins) {
if (requestOrigin.equalsIgnoreCase(allowedOrigin)) {
return requestOrigin;
}
}
}
for (String allowedOrigin : this.allowedOrigins) {
if (requestOrigin.equalsIgnoreCase(allowedOrigin)) {
return requestOrigin;
if (!ObjectUtils.isEmpty(this.allowedOriginsPatterns)) {
for (Pattern allowedOriginsPattern : this.allowedOriginsPatterns) {
if (allowedOriginsPattern.pattern().equals(ALL_PATTERN)) {
if (this.allowCredentials != Boolean.TRUE) {
return ALL;
}
else {
return requestOrigin;
}
}
else if (allowedOriginsPattern.matcher(requestOrigin).matches()) {
return requestOrigin;
}
}
}

84
spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -50,6 +50,8 @@ public class CorsConfigurationTests { @@ -50,6 +50,8 @@ public class CorsConfigurationTests {
assertThat(config.getAllowCredentials()).isNull();
config.setMaxAge((Long) null);
assertThat(config.getMaxAge()).isNull();
config.setAllowedOriginsPatterns(null);
assertThat(config.getAllowedOriginsPatterns()).isNull();
}
@Test
@ -68,6 +70,8 @@ public class CorsConfigurationTests { @@ -68,6 +70,8 @@ public class CorsConfigurationTests {
assertThat((boolean) config.getAllowCredentials()).isTrue();
config.setMaxAge(123L);
assertThat(config.getMaxAge()).isEqualTo(new Long(123));
config.addAllowedOriginPattern(".*\\.example\\.com");
assertThat(config.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.example\\.com"));
}
@Test
@ -101,6 +105,7 @@ public class CorsConfigurationTests { @@ -101,6 +105,7 @@ public class CorsConfigurationTests {
config.addAllowedMethod(HttpMethod.GET.name());
config.setMaxAge(123L);
config.setAllowCredentials(true);
config.setAllowedOriginsPatterns(Arrays.asList(".*\\.example\\.com"));
CorsConfiguration other = new CorsConfiguration();
config = config.combine(other);
assertThat(config.getAllowedOrigins()).isEqualTo(Arrays.asList("*"));
@ -109,6 +114,7 @@ public class CorsConfigurationTests { @@ -109,6 +114,7 @@ public class CorsConfigurationTests {
assertThat(config.getAllowedMethods()).isEqualTo(Arrays.asList(HttpMethod.GET.name()));
assertThat(config.getMaxAge()).isEqualTo(new Long(123));
assertThat((boolean) config.getAllowCredentials()).isTrue();
assertThat(config.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.example\\.com"));
}
@Test // SPR-15772
@ -142,25 +148,60 @@ public class CorsConfigurationTests { @@ -142,25 +148,60 @@ public class CorsConfigurationTests {
HttpMethod.POST.name()));
}
@Test
public void combinePatternWithDefaultPermitValues() {
CorsConfiguration config = new CorsConfiguration().applyPermitDefaultValues();
CorsConfiguration other = new CorsConfiguration();
other.addAllowedOriginPattern(".*\\.com");
CorsConfiguration combinedConfig = other.combine(config);
assertThat(combinedConfig.getAllowedOrigins()).isNull();
assertThat(combinedConfig.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.com"));
combinedConfig = config.combine(other);
assertThat(combinedConfig.getAllowedOrigins()).isNull();
assertThat(combinedConfig.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.com"));
}
@Test
public void combinePatternWithDefaultPermitValuesAndCustomOrigin() {
CorsConfiguration config = new CorsConfiguration().applyPermitDefaultValues();
config.setAllowedOrigins(Arrays.asList("https://domain.com"));
CorsConfiguration other = new CorsConfiguration();
other.addAllowedOriginPattern(".*\\.com");
CorsConfiguration combinedConfig = other.combine(config);
assertThat(combinedConfig.getAllowedOrigins()).isEqualTo(Arrays.asList("https://domain.com"));
assertThat(combinedConfig.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.com"));
combinedConfig = config.combine(other);
assertThat(combinedConfig.getAllowedOrigins()).isEqualTo(Arrays.asList("https://domain.com"));
assertThat(combinedConfig.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.com"));
}
@Test
public void combineWithAsteriskWildCard() {
CorsConfiguration config = new CorsConfiguration();
config.addAllowedOrigin("*");
config.addAllowedHeader("*");
config.addAllowedMethod("*");
config.addAllowedOriginPattern(".*");
CorsConfiguration other = new CorsConfiguration();
other.addAllowedOrigin("https://domain.com");
other.addAllowedHeader("header1");
other.addExposedHeader("header2");
other.addAllowedOriginPattern(".*\\.company\\.com");
other.addAllowedMethod(HttpMethod.PUT.name());
CorsConfiguration combinedConfig = config.combine(other);
assertThat(combinedConfig.getAllowedOrigins()).isEqualTo(Arrays.asList("*"));
assertThat(combinedConfig.getAllowedHeaders()).isEqualTo(Arrays.asList("*"));
assertThat(combinedConfig.getAllowedMethods()).isEqualTo(Arrays.asList("*"));
assertThat(combinedConfig.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*"));
combinedConfig = other.combine(config);
assertThat(combinedConfig.getAllowedOrigins()).isEqualTo(Arrays.asList("*"));
assertThat(combinedConfig.getAllowedHeaders()).isEqualTo(Arrays.asList("*"));
assertThat(combinedConfig.getAllowedMethods()).isEqualTo(Arrays.asList("*"));
assertThat(combinedConfig.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*"));
}
@Test // SPR-14792
@ -174,16 +215,20 @@ public class CorsConfigurationTests { @@ -174,16 +215,20 @@ public class CorsConfigurationTests {
config.addExposedHeader("header4");
config.addAllowedMethod(HttpMethod.GET.name());
config.addAllowedMethod(HttpMethod.PUT.name());
config.addAllowedOriginPattern(".*\\.domain1\\.com");
config.addAllowedOriginPattern(".*\\.domain2\\.com");
CorsConfiguration other = new CorsConfiguration();
other.addAllowedOrigin("https://domain1.com");
other.addAllowedHeader("header1");
other.addExposedHeader("header3");
other.addAllowedMethod(HttpMethod.GET.name());
other.addAllowedOriginPattern(".*\\.domain1\\.com");
CorsConfiguration combinedConfig = config.combine(other);
assertThat(combinedConfig.getAllowedOrigins()).isEqualTo(Arrays.asList("https://domain1.com", "https://domain2.com"));
assertThat(combinedConfig.getAllowedHeaders()).isEqualTo(Arrays.asList("header1", "header2"));
assertThat(combinedConfig.getExposedHeaders()).isEqualTo(Arrays.asList("header3", "header4"));
assertThat(combinedConfig.getAllowedMethods()).isEqualTo(Arrays.asList(HttpMethod.GET.name(), HttpMethod.PUT.name()));
assertThat(combinedConfig.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.domain1\\.com", ".*\\.domain2\\.com"));
}
@Test
@ -195,6 +240,7 @@ public class CorsConfigurationTests { @@ -195,6 +240,7 @@ public class CorsConfigurationTests {
config.addAllowedMethod(HttpMethod.GET.name());
config.setMaxAge(123L);
config.setAllowCredentials(true);
config.addAllowedOriginPattern(".*\\.domain1\\.com");
CorsConfiguration other = new CorsConfiguration();
other.addAllowedOrigin("https://domain2.com");
other.addAllowedHeader("header2");
@ -202,6 +248,7 @@ public class CorsConfigurationTests { @@ -202,6 +248,7 @@ public class CorsConfigurationTests {
other.addAllowedMethod(HttpMethod.PUT.name());
other.setMaxAge(456L);
other.setAllowCredentials(false);
other.addAllowedOriginPattern(".*\\.domain2\\.com");
config = config.combine(other);
assertThat(config.getAllowedOrigins()).isEqualTo(Arrays.asList("https://domain1.com", "https://domain2.com"));
assertThat(config.getAllowedHeaders()).isEqualTo(Arrays.asList("header1", "header2"));
@ -209,6 +256,7 @@ public class CorsConfigurationTests { @@ -209,6 +256,7 @@ public class CorsConfigurationTests {
assertThat(config.getAllowedMethods()).isEqualTo(Arrays.asList(HttpMethod.GET.name(), HttpMethod.PUT.name()));
assertThat(config.getMaxAge()).isEqualTo(new Long(456));
assertThat((boolean) config.getAllowCredentials()).isFalse();
assertThat(config.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.domain1\\.com", ".*\\.domain2\\.com"));
}
@Test
@ -237,6 +285,32 @@ public class CorsConfigurationTests { @@ -237,6 +285,32 @@ public class CorsConfigurationTests {
assertThat(config.checkOrigin("https://domain.com")).isNull();
}
@Test
public void checkOriginPatternAllowed() {
CorsConfiguration config = new CorsConfiguration();
config.setAllowedOriginsPatterns(Arrays.asList(".*"));
assertThat(config.checkOrigin("https://domain.com")).isEqualTo("*");
config.setAllowCredentials(true);
assertThat(config.checkOrigin("https://domain.com")).isEqualTo("https://domain.com");
config.setAllowedOriginsPatterns(Arrays.asList(".*\\.domain\\.com"));
assertThat(config.checkOrigin("https://example.domain.com")).isEqualTo("https://example.domain.com");
config.setAllowCredentials(false);
assertThat(config.checkOrigin("https://example.domain.com")).isEqualTo("https://example.domain.com");
}
@Test
public void checkOriginPatternNotAllowed() {
CorsConfiguration config = new CorsConfiguration();
assertThat(config.checkOrigin(null)).isNull();
assertThat(config.checkOrigin("https://domain.com")).isNull();
config.addAllowedOriginPattern(".*");
assertThat(config.checkOrigin(null)).isNull();
config.setAllowedOriginsPatterns(Arrays.asList(".*\\.domain1\\.com"));
assertThat(config.checkOrigin("https://domain2.com")).isNull();
config.setAllowedOriginsPatterns(new ArrayList<>());
assertThat(config.checkOrigin("https://domain.com")).isNull();
}
@Test
public void checkMethodAllowed() {
CorsConfiguration config = new CorsConfiguration();
@ -291,4 +365,12 @@ public class CorsConfigurationTests { @@ -291,4 +365,12 @@ public class CorsConfigurationTests {
assertThat(config.getAllowedMethods()).isEqualTo(Arrays.asList("GET", "HEAD", "POST", "PATCH"));
}
@Test
public void permitDefaultDoesntSetOriginWhenPatternPresent() {
CorsConfiguration config = new CorsConfiguration();
config.addAllowedOriginPattern(".*\\.com");
config = config.applyPermitDefaultValues();
assertThat(config.getAllowedOrigins()).isNull();
assertThat(config.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.com"));
}
}

3
spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java

@ -315,6 +315,9 @@ public class RequestMappingHandlerMapping extends RequestMappingInfoHandlerMappi @@ -315,6 +315,9 @@ public class RequestMappingHandlerMapping extends RequestMappingInfoHandlerMappi
for (String origin : annotation.origins()) {
config.addAllowedOrigin(resolveCorsAnnotationValue(origin));
}
for (String originsPattern : annotation.originsPatterns()) {
config.addAllowedOriginPattern(resolveCorsAnnotationValue(originsPattern));
}
for (RequestMethod method : annotation.methods()) {
config.addAllowedMethod(method.name());
}

15
spring-webflux/src/test/java/org/springframework/web/reactive/handler/CorsUrlHandlerMappingTests.java

@ -110,6 +110,21 @@ public class CorsUrlHandlerMappingTests { @@ -110,6 +110,21 @@ public class CorsUrlHandlerMappingTests {
assertThat(exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)).isEqualTo("*");
}
@Test
public void actualRequestWithGlobalPatternCorsConfig() throws Exception {
CorsConfiguration mappedConfig = new CorsConfiguration();
mappedConfig.addAllowedOriginPattern(".*\\.domain2.com");
this.handlerMapping.setCorsConfigurations(Collections.singletonMap("/welcome.html", mappedConfig));
String origin = "https://example.domain2.com";
ServerWebExchange exchange = createExchange(HttpMethod.GET, "/welcome.html", origin);
Object actual = this.handlerMapping.getHandler(exchange).block();
assertThat(actual).isNotNull();
assertThat(actual).isSameAs(this.welcomeController);
assertThat(exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)).isEqualTo("https://example.domain2.com");
}
@Test
public void preFlightRequestWithGlobalCorsConfig() throws Exception {
CorsConfiguration mappedConfig = new CorsConfiguration();

33
spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/CrossOriginAnnotationIntegrationTests.java

@ -68,6 +68,7 @@ class CrossOriginAnnotationIntegrationTests extends AbstractRequestMappingIntegr @@ -68,6 +68,7 @@ class CrossOriginAnnotationIntegrationTests extends AbstractRequestMappingIntegr
context.register(WebConfig.class);
Properties props = new Properties();
props.setProperty("myOrigin", "https://site1.com");
props.setProperty("myOriginPattern", ".*\\.com");
context.getEnvironment().getPropertySources().addFirst(new PropertiesPropertySource("ps", props));
context.register(PropertySourcesPlaceholderConfigurer.class);
context.refresh();
@ -206,6 +207,26 @@ class CrossOriginAnnotationIntegrationTests extends AbstractRequestMappingIntegr @@ -206,6 +207,26 @@ class CrossOriginAnnotationIntegrationTests extends AbstractRequestMappingIntegr
assertThat(entity.getBody()).isEqualTo("placeholder");
}
@ParameterizedHttpServerTest
void customOriginPatternDefinedViaValueAttribute(HttpServer httpServer) throws Exception {
startServer(httpServer);
ResponseEntity<String> entity = performGet("/origin-pattern-value-attribute", this.headers, String.class);
assertThat(entity.getStatusCode()).isEqualTo(HttpStatus.OK);
assertThat(entity.getHeaders().getAccessControlAllowOrigin()).isEqualTo("https://site1.com");
assertThat(entity.getBody()).isEqualTo("pattern-value-attribute");
}
@ParameterizedHttpServerTest
void customOriginPatternDefinedViaPlaceholder(HttpServer httpServer) throws Exception {
startServer(httpServer);
ResponseEntity<String> entity = performGet("/origin-pattern-placeholder", this.headers, String.class);
assertThat(entity.getStatusCode()).isEqualTo(HttpStatus.OK);
assertThat(entity.getHeaders().getAccessControlAllowOrigin()).isEqualTo("https://site1.com");
assertThat(entity.getBody()).isEqualTo("pattern-placeholder");
}
@ParameterizedHttpServerTest
void classLevel(HttpServer httpServer) throws Exception {
startServer(httpServer);
@ -335,6 +356,18 @@ class CrossOriginAnnotationIntegrationTests extends AbstractRequestMappingIntegr @@ -335,6 +356,18 @@ class CrossOriginAnnotationIntegrationTests extends AbstractRequestMappingIntegr
public String customOriginDefinedViaPlaceholder() {
return "placeholder";
}
@CrossOrigin(originsPatterns = ".*\\.com")
@GetMapping("/origin-pattern-value-attribute")
public String customOriginPatternDefinedViaValueAttribute() {
return "pattern-value-attribute";
}
@CrossOrigin(originsPatterns = "${myOriginPattern}")
@GetMapping("/origin-pattern-placeholder")
public String customOriginPatternDefinedViaPlaceholder() {
return "pattern-placeholder";
}
}

3
spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerMapping.java

@ -451,6 +451,9 @@ public class RequestMappingHandlerMapping extends RequestMappingInfoHandlerMappi @@ -451,6 +451,9 @@ public class RequestMappingHandlerMapping extends RequestMappingInfoHandlerMappi
for (String origin : annotation.origins()) {
config.addAllowedOrigin(resolveCorsAnnotationValue(origin));
}
for (String originPatter : annotation.originsPatterns()) {
config.addAllowedOriginPattern(resolveCorsAnnotationValue(originPatter));
}
for (RequestMethod method : annotation.methods()) {
config.addAllowedMethod(method.name());
}

14
spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java

@ -121,8 +121,20 @@ class CorsAbstractHandlerMappingTests { @@ -121,8 +121,20 @@ class CorsAbstractHandlerMappingTests {
}
@PathPatternsParameterizedTest
void preflightRequestWithMappedCorsConfig(TestHandlerMapping mapping) throws Exception {
void actualRequestWithMappedPatternCorsConfiguration(TestHandlerMapping mapping) throws Exception {
CorsConfiguration config = new CorsConfiguration();
config.addAllowedOriginPattern(".*\\.domain2\\.com");
mapping.setCorsConfigurations(Collections.singletonMap("/foo", config));
MockHttpServletRequest request = getCorsRequest("/foo");
HandlerExecutionChain chain = mapping.getHandler(request);
assertThat(chain).isNotNull();
assertThat(chain.getHandler()).isInstanceOf(SimpleHandler.class);
assertThat(mapping.getRequiredCorsConfig().getAllowedOriginsPatterns()).containsExactly(".*\\.domain2\\.com");
}
@PathPatternsParameterizedTest
void preflightRequestWithMappedCorsConfig(TestHandlerMapping mapping) throws Exception {
CorsConfiguration config = new CorsConfiguration();
config.addAllowedOrigin("*");
mapping.setCorsConfigurations(Collections.singletonMap("/foo", config));

35
spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java

@ -72,6 +72,7 @@ class CrossOriginTests { @@ -72,6 +72,7 @@ class CrossOriginTests {
StaticWebApplicationContext wac = new StaticWebApplicationContext();
Properties props = new Properties();
props.setProperty("myOrigin", "https://example.com");
props.setProperty("myDomainPattern", ".*\\.example\\.com");
wac.getEnvironment().getPropertySources().addFirst(new PropertiesPropertySource("ps", props));
wac.registerSingleton("ppc", PropertySourcesPlaceholderConfigurer.class);
wac.refresh();
@ -197,6 +198,30 @@ class CrossOriginTests { @@ -197,6 +198,30 @@ class CrossOriginTests {
assertThat(config.getAllowCredentials()).isNull();
}
@PathPatternsParameterizedTest
public void customOriginPatternViaValueAttribute(TestRequestMappingInfoHandlerMapping mapping) throws Exception {
mapping.registerHandler(new MethodLevelController());
this.request.setRequestURI("/customOriginPattern");
HandlerExecutionChain chain = mapping.getHandler(request);
CorsConfiguration config = getCorsConfiguration(chain, false);
assertThat(config).isNotNull();
assertThat(config.getAllowedOrigins()).isNull();
assertThat(config.getAllowedOriginsPatterns()).isEqualTo(Collections.singletonList(".*\\.example\\.com"));
assertThat(config.getAllowCredentials()).isNull();
}
@PathPatternsParameterizedTest
public void customOriginPatternViaPlaceholder(TestRequestMappingInfoHandlerMapping mapping) throws Exception {
mapping.registerHandler(new MethodLevelController());
this.request.setRequestURI("/customOriginPatternPlaceholder");
HandlerExecutionChain chain = mapping.getHandler(request);
CorsConfiguration config = getCorsConfiguration(chain, false);
assertThat(config).isNotNull();
assertThat(config.getAllowedOrigins()).isNull();
assertThat(config.getAllowedOriginsPatterns()).isEqualTo(Collections.singletonList(".*\\.example\\.com"));
assertThat(config.getAllowCredentials()).isNull();
}
@PathPatternsParameterizedTest
void bogusAllowCredentialsValue(TestRequestMappingInfoHandlerMapping mapping) {
assertThatIllegalStateException().isThrownBy(() ->
@ -407,6 +432,16 @@ class CrossOriginTests { @@ -407,6 +432,16 @@ class CrossOriginTests {
@RequestMapping("/someOrigin")
public void customOriginDefinedViaPlaceholder() {
}
@CrossOrigin(originsPatterns = ".*\\.example\\.com")
@RequestMapping("/customOriginPattern")
public void customOriginPatternDefinedViaValueAttribute() {
}
@CrossOrigin(originsPatterns = "${myDomainPattern}")
@RequestMapping("/customOriginPatternPlaceholder")
public void customOriginPatternDefinedViaPlaceholder() {
}
}

Loading…
Cancel
Save