From 61569003b5db89da7f841a98489a012fdee50012 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Tue, 10 Mar 2020 17:10:24 +0100 Subject: [PATCH] CORS support in HTTP header predicate This commit introduces CORS support for the HeadersPredicate in WebMvc.fn and WebFlux.fn. Closes gh-24564 --- .../reactive/function/server/RequestPredicates.java | 7 ++++++- .../function/server/RequestPredicatesTests.java | 12 ++++++++++++ .../web/servlet/function/RequestPredicates.java | 7 ++++++- .../web/servlet/function/RequestPredicatesTests.java | 10 ++++++++++ 4 files changed, 34 insertions(+), 2 deletions(-) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java index 2569ce2d1a..9a08d79ce2 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java @@ -550,7 +550,12 @@ public abstract class RequestPredicates { @Override public boolean test(ServerRequest request) { - return this.headersPredicate.test(request.headers()); + if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) { + return true; + } + else { + return this.headersPredicate.test(request.headers()); + } } @Override diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java index 4282baec27..3a8536cdf7 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java @@ -198,6 +198,18 @@ public class RequestPredicatesTests { assertThat(predicate.test(request)).isFalse(); } + @Test + public void headersCors() { + RequestPredicate predicate = RequestPredicates.headers(headers -> false); + MockServerHttpRequest mockRequest = MockServerHttpRequest.options("https://example.com") + .header("Origin", "https://example.com") + .header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "PUT") + .build(); + ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); + assertThat(predicate.test(request)).isTrue(); + } + + @Test public void contentType() { MediaType json = MediaType.APPLICATION_JSON; diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java index 667c211932..1fe1e7f31e 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RequestPredicates.java @@ -543,7 +543,12 @@ public abstract class RequestPredicates { @Override public boolean test(ServerRequest request) { - return this.headersPredicate.test(request.headers()); + if (CorsUtils.isPreFlightRequest(request.servletRequest())) { + return true; + } + else { + return this.headersPredicate.test(request.headers()); + } } @Override diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java index ab07961024..1297bb61ed 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RequestPredicatesTests.java @@ -197,6 +197,16 @@ public class RequestPredicatesTests { assertThat(predicate.test(request)).isFalse(); } + @Test + public void headersCors() { + RequestPredicate predicate = RequestPredicates.headers(headers -> false); + MockHttpServletRequest servletRequest = new MockHttpServletRequest("OPTIONS", "https://example.com"); + servletRequest.addHeader("Origin", "https://example.com"); + servletRequest.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "PUT"); + ServerRequest request = new DefaultServerRequest(servletRequest, emptyList()); + assertThat(predicate.test(request)).isTrue(); + } + @Test public void contentType() { MediaType json = MediaType.APPLICATION_JSON;