From 8632118e8d39230ccee42a882b463dd6232d547c Mon Sep 17 00:00:00 2001 From: Ruslan Akhundov Date: Mon, 4 May 2020 09:55:30 +0100 Subject: [PATCH] CorsConfiguration now supports pattern based origins. Closes gh-24763 --- .../web/bind/annotation/CrossOrigin.java | 20 ++- .../web/cors/CorsConfiguration.java | 120 +++++++++++++++--- .../web/cors/CorsConfigurationTests.java | 84 +++++++++++- .../RequestMappingHandlerMapping.java | 3 + .../handler/CorsUrlHandlerMappingTests.java | 15 +++ ...CrossOriginAnnotationIntegrationTests.java | 33 +++++ .../RequestMappingHandlerMapping.java | 3 + .../CorsAbstractHandlerMappingTests.java | 14 +- .../method/annotation/CrossOriginTests.java | 35 +++++ 9 files changed, 308 insertions(+), 19 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/bind/annotation/CrossOrigin.java b/spring-web/src/main/java/org/springframework/web/bind/annotation/CrossOrigin.java index 7fb1b88c4d..e9c9e1f8c7 100644 --- a/spring-web/src/main/java/org/springframework/web/bind/annotation/CrossOrigin.java +++ b/spring-web/src/main/java/org/springframework/web/bind/annotation/CrossOrigin.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,6 +45,7 @@ import org.springframework.web.cors.CorsConfiguration; * @author Russell Allen * @author Sebastien Deleuze * @author Sam Brannen + * @author Ruslan Akhundov * @since 4.2 */ @Target({ElementType.TYPE, ElementType.METHOD}) @@ -93,6 +94,23 @@ public @interface CrossOrigin { @AliasFor("value") String[] origins() default {}; + /** + * The list of allowed origins patterns that be specific origins, e.g. + * {@code ".*\.domain1\.com"}, or {@code ".*"} for matching all origins. + *

A matched origin is listed in the {@code Access-Control-Allow-Origin} + * response header of preflight actual CORS requests. + *

By default all origins are allowed. + *

Note: CORS checks use values from "Forwarded" + * (RFC 7239), + * "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" headers, + * if present, in order to reflect the client-originated address. + * Consider using the {@code ForwardedHeaderFilter} in order to choose from a + * central place whether to extract and use, or to discard such headers. + * See the Spring Framework reference for more on this filter. + * @see #value + */ + String[] originsPatterns() default {}; + /** * The list of request headers that are permitted in actual requests, * possibly {@code "*"} to allow all headers. diff --git a/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java b/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java index 884a13add2..6b12c259f8 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java +++ b/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.LinkedHashSet; import java.util.List; import java.util.Set; +import java.util.regex.Pattern; import java.util.stream.Collectors; import org.springframework.http.HttpMethod; @@ -45,6 +46,7 @@ import org.springframework.util.StringUtils; * @author Rossen Stoyanchev * @author Juergen Hoeller * @author Sam Brannen + * @author Ruslan Akhundov * @since 4.2 * @see CORS spec */ @@ -52,6 +54,8 @@ public class CorsConfiguration { /** Wildcard representing all origins, methods, or headers. */ public static final String ALL = "*"; + /** Wildcard representing pattern that matches all origins. */ + public static final String ALL_PATTERN = ".*"; private static final List DEFAULT_METHODS = Collections.unmodifiableList( Arrays.asList(HttpMethod.GET, HttpMethod.HEAD)); @@ -62,10 +66,19 @@ public class CorsConfiguration { private static final List DEFAULT_PERMIT_ALL = Collections.unmodifiableList( Collections.singletonList(ALL)); + private static final List DEFAULT_PERMIT_ALL_PATTERN_STR = Collections.unmodifiableList( + Collections.singletonList(ALL_PATTERN)); + + private static final List DEFAULT_PERMIT_ALL_PATTERN = Collections.unmodifiableList( + Collections.singletonList(Pattern.compile(ALL_PATTERN))); + @Nullable private List allowedOrigins; + @Nullable + private List allowedOriginsPatterns; + @Nullable private List allowedMethods; @@ -99,6 +112,7 @@ public class CorsConfiguration { */ public CorsConfiguration(CorsConfiguration other) { this.allowedOrigins = other.allowedOrigins; + this.allowedOriginsPatterns = other.allowedOriginsPatterns; this.allowedMethods = other.allowedMethods; this.resolvedMethods = other.resolvedMethods; this.allowedHeaders = other.allowedHeaders; @@ -140,6 +154,54 @@ public class CorsConfiguration { this.allowedOrigins.add(origin); } + /** + * Set the origins patterns to allow, e.g. {@code "*.com"}. + *

By default this is not set. + */ + public CorsConfiguration setAllowedOriginsPatterns(@Nullable List allowedOriginsPatterns) { + if (allowedOriginsPatterns == null) { + this.allowedOriginsPatterns = null; + } + else { + this.allowedOriginsPatterns = new ArrayList<>(allowedOriginsPatterns.size()); + for (String pattern : allowedOriginsPatterns) { + this.allowedOriginsPatterns.add(Pattern.compile(pattern)); + } + } + + return this; + } + + /** + * Return the configured origins patterns to allow, or {@code null} if none. + * + * @see #addAllowedOriginPattern(String) + * @see #setAllowedOriginsPatterns(List) + */ + @Nullable + public List getAllowedOriginsPatterns() { + if (this.allowedOriginsPatterns == null) { + return null; + } + if (this.allowedOriginsPatterns == DEFAULT_PERMIT_ALL_PATTERN) { + return DEFAULT_PERMIT_ALL_PATTERN_STR; + } + return this.allowedOriginsPatterns.stream().map(Pattern::toString).collect(Collectors.toList()); + } + + /** + * Add an origin pattern to allow. + */ + public void addAllowedOriginPattern(String originPattern) { + if (this.allowedOriginsPatterns == null) { + this.allowedOriginsPatterns = new ArrayList<>(4); + } + else if (this.allowedOriginsPatterns == DEFAULT_PERMIT_ALL_PATTERN) { + setAllowedOriginsPatterns(DEFAULT_PERMIT_ALL_PATTERN_STR); + } + this.allowedOriginsPatterns.add(Pattern.compile(originPattern)); + } + /** * Set the HTTP methods to allow, e.g. {@code "GET"}, {@code "POST"}, * {@code "PUT"}, etc. @@ -351,7 +413,7 @@ public class CorsConfiguration { * */ public CorsConfiguration applyPermitDefaultValues() { - if (this.allowedOrigins == null) { + if (this.allowedOrigins == null && this.allowedOriginsPatterns == null) { this.allowedOrigins = DEFAULT_PERMIT_ALL; } if (this.allowedMethods == null) { @@ -392,7 +454,14 @@ public class CorsConfiguration { return this; } CorsConfiguration config = new CorsConfiguration(this); - config.setAllowedOrigins(combine(getAllowedOrigins(), other.getAllowedOrigins())); + List combinedOrigins = combine(getAllowedOrigins(), other.getAllowedOrigins()); + List combinedOriginsPatterns = combine(getAllowedOriginsPatterns(), other.getAllowedOriginsPatterns()); + if (combinedOrigins == DEFAULT_PERMIT_ALL && combinedOriginsPatterns != DEFAULT_PERMIT_ALL_PATTERN_STR + && !CollectionUtils.isEmpty(combinedOriginsPatterns)) { + combinedOrigins = null; + } + config.setAllowedOrigins(combinedOrigins); + config.setAllowedOriginsPatterns(combinedOriginsPatterns); config.setAllowedMethods(combine(getAllowedMethods(), other.getAllowedMethods())); config.setAllowedHeaders(combine(getAllowedHeaders(), other.getAllowedHeaders())); config.setExposedHeaders(combine(getExposedHeaders(), other.getExposedHeaders())); @@ -414,15 +483,20 @@ public class CorsConfiguration { if (source == null) { return other; } - if (source == DEFAULT_PERMIT_ALL || source == DEFAULT_PERMIT_METHODS) { + if (source == DEFAULT_PERMIT_ALL || source == DEFAULT_PERMIT_METHODS + || source == DEFAULT_PERMIT_ALL_PATTERN_STR) { return other; } - if (other == DEFAULT_PERMIT_ALL || other == DEFAULT_PERMIT_METHODS) { + if (other == DEFAULT_PERMIT_ALL || other == DEFAULT_PERMIT_METHODS + || other == DEFAULT_PERMIT_ALL_PATTERN_STR) { return source; } if (source.contains(ALL) || other.contains(ALL)) { return new ArrayList<>(Collections.singletonList(ALL)); } + if ( source.contains(ALL_PATTERN) || other.contains(ALL_PATTERN)) { + return new ArrayList<>(Collections.singletonList(ALL_PATTERN)); + } Set combined = new LinkedHashSet<>(source); combined.addAll(other); return new ArrayList<>(combined); @@ -439,21 +513,35 @@ public class CorsConfiguration { if (!StringUtils.hasText(requestOrigin)) { return null; } - if (ObjectUtils.isEmpty(this.allowedOrigins)) { - return null; - } - if (this.allowedOrigins.contains(ALL)) { - if (this.allowCredentials != Boolean.TRUE) { - return ALL; + if (!ObjectUtils.isEmpty(this.allowedOrigins)) { + if (this.allowedOrigins.contains(ALL)) { + if (this.allowCredentials != Boolean.TRUE) { + return ALL; + } + else { + return requestOrigin; + } } - else { - return requestOrigin; + for (String allowedOrigin : this.allowedOrigins) { + if (requestOrigin.equalsIgnoreCase(allowedOrigin)) { + return requestOrigin; + } } } - for (String allowedOrigin : this.allowedOrigins) { - if (requestOrigin.equalsIgnoreCase(allowedOrigin)) { - return requestOrigin; + if (!ObjectUtils.isEmpty(this.allowedOriginsPatterns)) { + for (Pattern allowedOriginsPattern : this.allowedOriginsPatterns) { + if (allowedOriginsPattern.pattern().equals(ALL_PATTERN)) { + if (this.allowCredentials != Boolean.TRUE) { + return ALL; + } + else { + return requestOrigin; + } + } + else if (allowedOriginsPattern.matcher(requestOrigin).matches()) { + return requestOrigin; + } } } diff --git a/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java b/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java index c18a456111..0d311805ce 100644 --- a/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java +++ b/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,6 +50,8 @@ public class CorsConfigurationTests { assertThat(config.getAllowCredentials()).isNull(); config.setMaxAge((Long) null); assertThat(config.getMaxAge()).isNull(); + config.setAllowedOriginsPatterns(null); + assertThat(config.getAllowedOriginsPatterns()).isNull(); } @Test @@ -68,6 +70,8 @@ public class CorsConfigurationTests { assertThat((boolean) config.getAllowCredentials()).isTrue(); config.setMaxAge(123L); assertThat(config.getMaxAge()).isEqualTo(new Long(123)); + config.addAllowedOriginPattern(".*\\.example\\.com"); + assertThat(config.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.example\\.com")); } @Test @@ -101,6 +105,7 @@ public class CorsConfigurationTests { config.addAllowedMethod(HttpMethod.GET.name()); config.setMaxAge(123L); config.setAllowCredentials(true); + config.setAllowedOriginsPatterns(Arrays.asList(".*\\.example\\.com")); CorsConfiguration other = new CorsConfiguration(); config = config.combine(other); assertThat(config.getAllowedOrigins()).isEqualTo(Arrays.asList("*")); @@ -109,6 +114,7 @@ public class CorsConfigurationTests { assertThat(config.getAllowedMethods()).isEqualTo(Arrays.asList(HttpMethod.GET.name())); assertThat(config.getMaxAge()).isEqualTo(new Long(123)); assertThat((boolean) config.getAllowCredentials()).isTrue(); + assertThat(config.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.example\\.com")); } @Test // SPR-15772 @@ -142,25 +148,60 @@ public class CorsConfigurationTests { HttpMethod.POST.name())); } + @Test + public void combinePatternWithDefaultPermitValues() { + CorsConfiguration config = new CorsConfiguration().applyPermitDefaultValues(); + CorsConfiguration other = new CorsConfiguration(); + other.addAllowedOriginPattern(".*\\.com"); + + CorsConfiguration combinedConfig = other.combine(config); + assertThat(combinedConfig.getAllowedOrigins()).isNull(); + assertThat(combinedConfig.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.com")); + + combinedConfig = config.combine(other); + assertThat(combinedConfig.getAllowedOrigins()).isNull(); + assertThat(combinedConfig.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.com")); + } + + @Test + public void combinePatternWithDefaultPermitValuesAndCustomOrigin() { + CorsConfiguration config = new CorsConfiguration().applyPermitDefaultValues(); + config.setAllowedOrigins(Arrays.asList("https://domain.com")); + CorsConfiguration other = new CorsConfiguration(); + other.addAllowedOriginPattern(".*\\.com"); + + CorsConfiguration combinedConfig = other.combine(config); + assertThat(combinedConfig.getAllowedOrigins()).isEqualTo(Arrays.asList("https://domain.com")); + assertThat(combinedConfig.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.com")); + + combinedConfig = config.combine(other); + assertThat(combinedConfig.getAllowedOrigins()).isEqualTo(Arrays.asList("https://domain.com")); + assertThat(combinedConfig.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.com")); + } + @Test public void combineWithAsteriskWildCard() { CorsConfiguration config = new CorsConfiguration(); config.addAllowedOrigin("*"); config.addAllowedHeader("*"); config.addAllowedMethod("*"); + config.addAllowedOriginPattern(".*"); CorsConfiguration other = new CorsConfiguration(); other.addAllowedOrigin("https://domain.com"); other.addAllowedHeader("header1"); other.addExposedHeader("header2"); + other.addAllowedOriginPattern(".*\\.company\\.com"); other.addAllowedMethod(HttpMethod.PUT.name()); CorsConfiguration combinedConfig = config.combine(other); assertThat(combinedConfig.getAllowedOrigins()).isEqualTo(Arrays.asList("*")); assertThat(combinedConfig.getAllowedHeaders()).isEqualTo(Arrays.asList("*")); assertThat(combinedConfig.getAllowedMethods()).isEqualTo(Arrays.asList("*")); + assertThat(combinedConfig.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*")); combinedConfig = other.combine(config); assertThat(combinedConfig.getAllowedOrigins()).isEqualTo(Arrays.asList("*")); assertThat(combinedConfig.getAllowedHeaders()).isEqualTo(Arrays.asList("*")); assertThat(combinedConfig.getAllowedMethods()).isEqualTo(Arrays.asList("*")); + assertThat(combinedConfig.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*")); } @Test // SPR-14792 @@ -174,16 +215,20 @@ public class CorsConfigurationTests { config.addExposedHeader("header4"); config.addAllowedMethod(HttpMethod.GET.name()); config.addAllowedMethod(HttpMethod.PUT.name()); + config.addAllowedOriginPattern(".*\\.domain1\\.com"); + config.addAllowedOriginPattern(".*\\.domain2\\.com"); CorsConfiguration other = new CorsConfiguration(); other.addAllowedOrigin("https://domain1.com"); other.addAllowedHeader("header1"); other.addExposedHeader("header3"); other.addAllowedMethod(HttpMethod.GET.name()); + other.addAllowedOriginPattern(".*\\.domain1\\.com"); CorsConfiguration combinedConfig = config.combine(other); assertThat(combinedConfig.getAllowedOrigins()).isEqualTo(Arrays.asList("https://domain1.com", "https://domain2.com")); assertThat(combinedConfig.getAllowedHeaders()).isEqualTo(Arrays.asList("header1", "header2")); assertThat(combinedConfig.getExposedHeaders()).isEqualTo(Arrays.asList("header3", "header4")); assertThat(combinedConfig.getAllowedMethods()).isEqualTo(Arrays.asList(HttpMethod.GET.name(), HttpMethod.PUT.name())); + assertThat(combinedConfig.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.domain1\\.com", ".*\\.domain2\\.com")); } @Test @@ -195,6 +240,7 @@ public class CorsConfigurationTests { config.addAllowedMethod(HttpMethod.GET.name()); config.setMaxAge(123L); config.setAllowCredentials(true); + config.addAllowedOriginPattern(".*\\.domain1\\.com"); CorsConfiguration other = new CorsConfiguration(); other.addAllowedOrigin("https://domain2.com"); other.addAllowedHeader("header2"); @@ -202,6 +248,7 @@ public class CorsConfigurationTests { other.addAllowedMethod(HttpMethod.PUT.name()); other.setMaxAge(456L); other.setAllowCredentials(false); + other.addAllowedOriginPattern(".*\\.domain2\\.com"); config = config.combine(other); assertThat(config.getAllowedOrigins()).isEqualTo(Arrays.asList("https://domain1.com", "https://domain2.com")); assertThat(config.getAllowedHeaders()).isEqualTo(Arrays.asList("header1", "header2")); @@ -209,6 +256,7 @@ public class CorsConfigurationTests { assertThat(config.getAllowedMethods()).isEqualTo(Arrays.asList(HttpMethod.GET.name(), HttpMethod.PUT.name())); assertThat(config.getMaxAge()).isEqualTo(new Long(456)); assertThat((boolean) config.getAllowCredentials()).isFalse(); + assertThat(config.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.domain1\\.com", ".*\\.domain2\\.com")); } @Test @@ -237,6 +285,32 @@ public class CorsConfigurationTests { assertThat(config.checkOrigin("https://domain.com")).isNull(); } + @Test + public void checkOriginPatternAllowed() { + CorsConfiguration config = new CorsConfiguration(); + config.setAllowedOriginsPatterns(Arrays.asList(".*")); + assertThat(config.checkOrigin("https://domain.com")).isEqualTo("*"); + config.setAllowCredentials(true); + assertThat(config.checkOrigin("https://domain.com")).isEqualTo("https://domain.com"); + config.setAllowedOriginsPatterns(Arrays.asList(".*\\.domain\\.com")); + assertThat(config.checkOrigin("https://example.domain.com")).isEqualTo("https://example.domain.com"); + config.setAllowCredentials(false); + assertThat(config.checkOrigin("https://example.domain.com")).isEqualTo("https://example.domain.com"); + } + + @Test + public void checkOriginPatternNotAllowed() { + CorsConfiguration config = new CorsConfiguration(); + assertThat(config.checkOrigin(null)).isNull(); + assertThat(config.checkOrigin("https://domain.com")).isNull(); + config.addAllowedOriginPattern(".*"); + assertThat(config.checkOrigin(null)).isNull(); + config.setAllowedOriginsPatterns(Arrays.asList(".*\\.domain1\\.com")); + assertThat(config.checkOrigin("https://domain2.com")).isNull(); + config.setAllowedOriginsPatterns(new ArrayList<>()); + assertThat(config.checkOrigin("https://domain.com")).isNull(); + } + @Test public void checkMethodAllowed() { CorsConfiguration config = new CorsConfiguration(); @@ -291,4 +365,12 @@ public class CorsConfigurationTests { assertThat(config.getAllowedMethods()).isEqualTo(Arrays.asList("GET", "HEAD", "POST", "PATCH")); } + @Test + public void permitDefaultDoesntSetOriginWhenPatternPresent() { + CorsConfiguration config = new CorsConfiguration(); + config.addAllowedOriginPattern(".*\\.com"); + config = config.applyPermitDefaultValues(); + assertThat(config.getAllowedOrigins()).isNull(); + assertThat(config.getAllowedOriginsPatterns()).isEqualTo(Arrays.asList(".*\\.com")); + } } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java index fa82f405f7..8915de5866 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java @@ -315,6 +315,9 @@ public class RequestMappingHandlerMapping extends RequestMappingInfoHandlerMappi for (String origin : annotation.origins()) { config.addAllowedOrigin(resolveCorsAnnotationValue(origin)); } + for (String originsPattern : annotation.originsPatterns()) { + config.addAllowedOriginPattern(resolveCorsAnnotationValue(originsPattern)); + } for (RequestMethod method : annotation.methods()) { config.addAllowedMethod(method.name()); } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/handler/CorsUrlHandlerMappingTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/handler/CorsUrlHandlerMappingTests.java index 6342332645..3c7b1b7d3f 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/handler/CorsUrlHandlerMappingTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/handler/CorsUrlHandlerMappingTests.java @@ -110,6 +110,21 @@ public class CorsUrlHandlerMappingTests { assertThat(exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)).isEqualTo("*"); } + @Test + public void actualRequestWithGlobalPatternCorsConfig() throws Exception { + CorsConfiguration mappedConfig = new CorsConfiguration(); + mappedConfig.addAllowedOriginPattern(".*\\.domain2.com"); + this.handlerMapping.setCorsConfigurations(Collections.singletonMap("/welcome.html", mappedConfig)); + + String origin = "https://example.domain2.com"; + ServerWebExchange exchange = createExchange(HttpMethod.GET, "/welcome.html", origin); + Object actual = this.handlerMapping.getHandler(exchange).block(); + + assertThat(actual).isNotNull(); + assertThat(actual).isSameAs(this.welcomeController); + assertThat(exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)).isEqualTo("https://example.domain2.com"); + } + @Test public void preFlightRequestWithGlobalCorsConfig() throws Exception { CorsConfiguration mappedConfig = new CorsConfiguration(); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/CrossOriginAnnotationIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/CrossOriginAnnotationIntegrationTests.java index b07e752b04..2da85ba408 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/CrossOriginAnnotationIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/CrossOriginAnnotationIntegrationTests.java @@ -68,6 +68,7 @@ class CrossOriginAnnotationIntegrationTests extends AbstractRequestMappingIntegr context.register(WebConfig.class); Properties props = new Properties(); props.setProperty("myOrigin", "https://site1.com"); + props.setProperty("myOriginPattern", ".*\\.com"); context.getEnvironment().getPropertySources().addFirst(new PropertiesPropertySource("ps", props)); context.register(PropertySourcesPlaceholderConfigurer.class); context.refresh(); @@ -206,6 +207,26 @@ class CrossOriginAnnotationIntegrationTests extends AbstractRequestMappingIntegr assertThat(entity.getBody()).isEqualTo("placeholder"); } + @ParameterizedHttpServerTest + void customOriginPatternDefinedViaValueAttribute(HttpServer httpServer) throws Exception { + startServer(httpServer); + + ResponseEntity entity = performGet("/origin-pattern-value-attribute", this.headers, String.class); + assertThat(entity.getStatusCode()).isEqualTo(HttpStatus.OK); + assertThat(entity.getHeaders().getAccessControlAllowOrigin()).isEqualTo("https://site1.com"); + assertThat(entity.getBody()).isEqualTo("pattern-value-attribute"); + } + + @ParameterizedHttpServerTest + void customOriginPatternDefinedViaPlaceholder(HttpServer httpServer) throws Exception { + startServer(httpServer); + + ResponseEntity entity = performGet("/origin-pattern-placeholder", this.headers, String.class); + assertThat(entity.getStatusCode()).isEqualTo(HttpStatus.OK); + assertThat(entity.getHeaders().getAccessControlAllowOrigin()).isEqualTo("https://site1.com"); + assertThat(entity.getBody()).isEqualTo("pattern-placeholder"); + } + @ParameterizedHttpServerTest void classLevel(HttpServer httpServer) throws Exception { startServer(httpServer); @@ -335,6 +356,18 @@ class CrossOriginAnnotationIntegrationTests extends AbstractRequestMappingIntegr public String customOriginDefinedViaPlaceholder() { return "placeholder"; } + + @CrossOrigin(originsPatterns = ".*\\.com") + @GetMapping("/origin-pattern-value-attribute") + public String customOriginPatternDefinedViaValueAttribute() { + return "pattern-value-attribute"; + } + + @CrossOrigin(originsPatterns = "${myOriginPattern}") + @GetMapping("/origin-pattern-placeholder") + public String customOriginPatternDefinedViaPlaceholder() { + return "pattern-placeholder"; + } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerMapping.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerMapping.java index 1ae8d0c6c8..4efaa8b84c 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerMapping.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerMapping.java @@ -451,6 +451,9 @@ public class RequestMappingHandlerMapping extends RequestMappingInfoHandlerMappi for (String origin : annotation.origins()) { config.addAllowedOrigin(resolveCorsAnnotationValue(origin)); } + for (String originPatter : annotation.originsPatterns()) { + config.addAllowedOriginPattern(resolveCorsAnnotationValue(originPatter)); + } for (RequestMethod method : annotation.methods()) { config.addAllowedMethod(method.name()); } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java index a9aa4f37b9..6521bbb572 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java @@ -121,8 +121,20 @@ class CorsAbstractHandlerMappingTests { } @PathPatternsParameterizedTest - void preflightRequestWithMappedCorsConfig(TestHandlerMapping mapping) throws Exception { + void actualRequestWithMappedPatternCorsConfiguration(TestHandlerMapping mapping) throws Exception { + CorsConfiguration config = new CorsConfiguration(); + config.addAllowedOriginPattern(".*\\.domain2\\.com"); + mapping.setCorsConfigurations(Collections.singletonMap("/foo", config)); + MockHttpServletRequest request = getCorsRequest("/foo"); + HandlerExecutionChain chain = mapping.getHandler(request); + assertThat(chain).isNotNull(); + assertThat(chain.getHandler()).isInstanceOf(SimpleHandler.class); + assertThat(mapping.getRequiredCorsConfig().getAllowedOriginsPatterns()).containsExactly(".*\\.domain2\\.com"); + } + + @PathPatternsParameterizedTest + void preflightRequestWithMappedCorsConfig(TestHandlerMapping mapping) throws Exception { CorsConfiguration config = new CorsConfiguration(); config.addAllowedOrigin("*"); mapping.setCorsConfigurations(Collections.singletonMap("/foo", config)); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java index 4985d4c176..279a7d673c 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java @@ -72,6 +72,7 @@ class CrossOriginTests { StaticWebApplicationContext wac = new StaticWebApplicationContext(); Properties props = new Properties(); props.setProperty("myOrigin", "https://example.com"); + props.setProperty("myDomainPattern", ".*\\.example\\.com"); wac.getEnvironment().getPropertySources().addFirst(new PropertiesPropertySource("ps", props)); wac.registerSingleton("ppc", PropertySourcesPlaceholderConfigurer.class); wac.refresh(); @@ -197,6 +198,30 @@ class CrossOriginTests { assertThat(config.getAllowCredentials()).isNull(); } + @PathPatternsParameterizedTest + public void customOriginPatternViaValueAttribute(TestRequestMappingInfoHandlerMapping mapping) throws Exception { + mapping.registerHandler(new MethodLevelController()); + this.request.setRequestURI("/customOriginPattern"); + HandlerExecutionChain chain = mapping.getHandler(request); + CorsConfiguration config = getCorsConfiguration(chain, false); + assertThat(config).isNotNull(); + assertThat(config.getAllowedOrigins()).isNull(); + assertThat(config.getAllowedOriginsPatterns()).isEqualTo(Collections.singletonList(".*\\.example\\.com")); + assertThat(config.getAllowCredentials()).isNull(); + } + + @PathPatternsParameterizedTest + public void customOriginPatternViaPlaceholder(TestRequestMappingInfoHandlerMapping mapping) throws Exception { + mapping.registerHandler(new MethodLevelController()); + this.request.setRequestURI("/customOriginPatternPlaceholder"); + HandlerExecutionChain chain = mapping.getHandler(request); + CorsConfiguration config = getCorsConfiguration(chain, false); + assertThat(config).isNotNull(); + assertThat(config.getAllowedOrigins()).isNull(); + assertThat(config.getAllowedOriginsPatterns()).isEqualTo(Collections.singletonList(".*\\.example\\.com")); + assertThat(config.getAllowCredentials()).isNull(); + } + @PathPatternsParameterizedTest void bogusAllowCredentialsValue(TestRequestMappingInfoHandlerMapping mapping) { assertThatIllegalStateException().isThrownBy(() -> @@ -407,6 +432,16 @@ class CrossOriginTests { @RequestMapping("/someOrigin") public void customOriginDefinedViaPlaceholder() { } + + @CrossOrigin(originsPatterns = ".*\\.example\\.com") + @RequestMapping("/customOriginPattern") + public void customOriginPatternDefinedViaValueAttribute() { + } + + @CrossOrigin(originsPatterns = "${myDomainPattern}") + @RequestMapping("/customOriginPatternPlaceholder") + public void customOriginPatternDefinedViaPlaceholder() { + } }