diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsWebFilter.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsWebFilter.java new file mode 100644 index 0000000000..f397a4e57f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsWebFilter.java @@ -0,0 +1,74 @@ +package org.springframework.web.cors.reactive; + +import reactor.core.publisher.Mono; + +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.util.Assert; +import org.springframework.web.cors.*; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; + + +/** + * {@link WebFilter} that handles CORS preflight requests and intercepts + * CORS simple and actual requests thanks to a {@link CorsProcessor} implementation + * ({@link DefaultCorsProcessor} by default) in order to add the relevant CORS + * response headers (like {@code Access-Control-Allow-Origin}) using the provided + * {@link CorsConfigurationSource} (for example an {@link UrlBasedCorsConfigurationSource} + * instance. + * + *

This is an alternative to Spring WebFlux Java config CORS configuration, + * mostly useful for applications using the functional API. + * + * @author Sebastien Deleuze + * @since 5.0 + * @see CORS W3C recommendation + */ +public class CorsWebFilter implements WebFilter { + + private final CorsConfigurationSource configSource; + + private final CorsProcessor processor; + + + /** + * Constructor accepting a {@link CorsConfigurationSource} used by the filter + * to find the {@link CorsConfiguration} to use for each incoming request. + * @see UrlBasedCorsConfigurationSource + */ + public CorsWebFilter(CorsConfigurationSource configSource) { + this(configSource, new DefaultCorsProcessor()); + } + + /** + * Constructor accepting a {@link CorsConfigurationSource} used by the filter + * to find the {@link CorsConfiguration} to use for each incoming request and a + * custom {@link CorsProcessor} to use to apply the matched + * {@link CorsConfiguration} for a request. + * @see UrlBasedCorsConfigurationSource + */ + public CorsWebFilter(CorsConfigurationSource configSource, CorsProcessor processor) { + Assert.notNull(configSource, "CorsConfigurationSource must not be null"); + Assert.notNull(processor, "CorsProcessor must not be null"); + this.configSource = configSource; + this.processor = processor; + } + + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + ServerHttpRequest request = exchange.getRequest(); + if (CorsUtils.isCorsRequest(request)) { + CorsConfiguration corsConfiguration = this.configSource.getCorsConfiguration(exchange); + if (corsConfiguration != null) { + boolean isValid = this.processor.process(corsConfiguration, exchange); + if (!isValid || CorsUtils.isPreFlightRequest(request)) { + return Mono.empty(); + } + } + } + return chain.filter(exchange); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsWebFilterTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsWebFilterTests.java new file mode 100644 index 0000000000..76bac021a4 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsWebFilterTests.java @@ -0,0 +1,124 @@ +package org.springframework.web.cors.reactive; + + +import java.io.IOException; +import java.util.Arrays; + +import javax.servlet.ServletException; + +import org.junit.Before; +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpMethod; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerWebExchange; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.server.WebFilterChain; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.springframework.http.HttpHeaders.*; + +/** + * Unit tests for {@link CorsWebFilter}. + * @author Sebastien Deleuze + */ +public class CorsWebFilterTests { + + private CorsWebFilter filter; + + private final CorsConfiguration config = new CorsConfiguration(); + + @Before + public void setup() throws Exception { + config.setAllowedOrigins(Arrays.asList("http://domain1.com", "http://domain2.com")); + config.setAllowedMethods(Arrays.asList("GET", "POST")); + config.setAllowedHeaders(Arrays.asList("header1", "header2")); + config.setExposedHeaders(Arrays.asList("header3", "header4")); + config.setMaxAge(123L); + config.setAllowCredentials(false); + filter = new CorsWebFilter(r -> config); + } + + @Test + public void validActualRequest() { + + MockServerHttpRequest request = MockServerHttpRequest + .get("http://domain1.com/test.html") + .header(HOST, "domain1.com") + .header(ORIGIN, "http://domain2.com") + .header("header2", "foo") + .build(); + MockServerWebExchange exchange = new MockServerWebExchange(request); + + WebFilterChain filterChain = (filterExchange) -> { + try { + assertEquals("http://domain2.com", filterExchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("header3, header4", filterExchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_EXPOSE_HEADERS)); + } catch (AssertionError ex) { + return Mono.error(ex); + } + return Mono.empty(); + + }; + filter.filter(exchange, filterChain); + } + + @Test + public void invalidActualRequest() throws ServletException, IOException { + + MockServerHttpRequest request = MockServerHttpRequest + .delete("http://domain1.com/test.html") + .header(HOST, "domain1.com") + .header(ORIGIN, "http://domain2.com") + .header("header2", "foo") + .build(); + MockServerWebExchange exchange = new MockServerWebExchange(request); + + WebFilterChain filterChain = (filterExchange) -> Mono.error(new AssertionError("Invalid requests must not be forwarded to the filter chain")); + filter.filter(exchange, filterChain); + + assertNull(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + } + + @Test + public void validPreFlightRequest() throws ServletException, IOException { + + MockServerHttpRequest request = MockServerHttpRequest + .options("http://domain1.com/test.html") + .header(HOST, "domain1.com") + .header(ORIGIN, "http://domain2.com") + .header(ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.GET.name()) + .header(ACCESS_CONTROL_REQUEST_HEADERS, "header1, header2") + .build(); + MockServerWebExchange exchange = new MockServerWebExchange(request); + + WebFilterChain filterChain = (filterExchange) -> Mono.error(new AssertionError("Preflight requests must not be forwarded to the filter chain")); + filter.filter(exchange, filterChain); + + assertEquals("http://domain2.com", exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("header1, header2", exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS)); + assertEquals("header3, header4", exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_EXPOSE_HEADERS)); + assertEquals(123L, Long.parseLong(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_MAX_AGE))); + } + + @Test + public void invalidPreFlightRequest() throws ServletException, IOException { + + MockServerHttpRequest request = MockServerHttpRequest + .options("http://domain1.com/test.html") + .header(HOST, "domain1.com") + .header(ORIGIN, "http://domain2.com") + .header(ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.DELETE.name()) + .header(ACCESS_CONTROL_REQUEST_HEADERS, "header1, header2") + .build(); + MockServerWebExchange exchange = new MockServerWebExchange(request); + + WebFilterChain filterChain = (filterExchange) -> Mono.error(new AssertionError("Preflight requests must not be forwarded to the filter chain")); + filter.filter(exchange, filterChain); + + assertNull(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + } + +}