diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java index 2569ce2d1a..9a08d79ce2 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java @@ -550,7 +550,12 @@ public abstract class RequestPredicates { @Override public boolean test(ServerRequest request) { - return this.headersPredicate.test(request.headers()); + if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) { + return true; + } + else { + return this.headersPredicate.test(request.headers()); + } } @Override diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java index 4282baec27..3a8536cdf7 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java @@ -198,6 +198,18 @@ public class RequestPredicatesTests { assertThat(predicate.test(request)).isFalse(); } + @Test + public void headersCors() { + RequestPredicate predicate = RequestPredicates.headers(headers -> false); + MockServerHttpRequest mockRequest = MockServerHttpRequest.options("https://example.com") + .header("Origin", "https://example.com") + .header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "PUT") + .build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); + assertThat(predicate.test(request)).isTrue(); + } + + @Test public void contentType() { MediaType json = MediaType.APPLICATION_JSON; diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java index 667c211932..1fe1e7f31e 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java @@ -543,7 +543,12 @@ public abstract class RequestPredicates { @Override public boolean test(ServerRequest request) { - return this.headersPredicate.test(request.headers()); + if (CorsUtils.isPreFlightRequest(request.servletRequest())) { + return true; + } + else { + return this.headersPredicate.test(request.headers()); + } } @Override diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java index ab07961024..1297bb61ed 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java @@ -197,6 +197,16 @@ public class RequestPredicatesTests { assertThat(predicate.test(request)).isFalse(); } + @Test + public void headersCors() { + RequestPredicate predicate = RequestPredicates.headers(headers -> false); + MockHttpServletRequest servletRequest = new MockHttpServletRequest("OPTIONS", "https://example.com"); + servletRequest.addHeader("Origin", "https://example.com"); + servletRequest.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "PUT"); + ServerRequest request = new DefaultServerRequest(servletRequest, emptyList()); + assertThat(predicate.test(request)).isTrue(); + } + @Test public void contentType() { MediaType json = MediaType.APPLICATION_JSON;