Browse Source
This new WebFilter implementation is designed to allow initial CORS support when using WebFlux functional API. More high-level API may be introduced later. Issue: SPR-15567pull/1448/merge
2 changed files with 198 additions and 0 deletions
@ -0,0 +1,74 @@
@@ -0,0 +1,74 @@
|
||||
package org.springframework.web.cors.reactive; |
||||
|
||||
import reactor.core.publisher.Mono; |
||||
|
||||
import org.springframework.http.server.reactive.ServerHttpRequest; |
||||
import org.springframework.util.Assert; |
||||
import org.springframework.web.cors.*; |
||||
import org.springframework.web.server.ServerWebExchange; |
||||
import org.springframework.web.server.WebFilter; |
||||
import org.springframework.web.server.WebFilterChain; |
||||
|
||||
|
||||
/** |
||||
* {@link WebFilter} that handles CORS preflight requests and intercepts |
||||
* CORS simple and actual requests thanks to a {@link CorsProcessor} implementation |
||||
* ({@link DefaultCorsProcessor} by default) in order to add the relevant CORS |
||||
* response headers (like {@code Access-Control-Allow-Origin}) using the provided |
||||
* {@link CorsConfigurationSource} (for example an {@link UrlBasedCorsConfigurationSource} |
||||
* instance. |
||||
* |
||||
* <p>This is an alternative to Spring WebFlux Java config CORS configuration, |
||||
* mostly useful for applications using the functional API. |
||||
* |
||||
* @author Sebastien Deleuze |
||||
* @since 5.0 |
||||
* @see <a href="http://www.w3.org/TR/cors/">CORS W3C recommendation</a> |
||||
*/ |
||||
public class CorsWebFilter implements WebFilter { |
||||
|
||||
private final CorsConfigurationSource configSource; |
||||
|
||||
private final CorsProcessor processor; |
||||
|
||||
|
||||
/** |
||||
* Constructor accepting a {@link CorsConfigurationSource} used by the filter |
||||
* to find the {@link CorsConfiguration} to use for each incoming request. |
||||
* @see UrlBasedCorsConfigurationSource |
||||
*/ |
||||
public CorsWebFilter(CorsConfigurationSource configSource) { |
||||
this(configSource, new DefaultCorsProcessor()); |
||||
} |
||||
|
||||
/** |
||||
* Constructor accepting a {@link CorsConfigurationSource} used by the filter |
||||
* to find the {@link CorsConfiguration} to use for each incoming request and a |
||||
* custom {@link CorsProcessor} to use to apply the matched |
||||
* {@link CorsConfiguration} for a request. |
||||
* @see UrlBasedCorsConfigurationSource |
||||
*/ |
||||
public CorsWebFilter(CorsConfigurationSource configSource, CorsProcessor processor) { |
||||
Assert.notNull(configSource, "CorsConfigurationSource must not be null"); |
||||
Assert.notNull(processor, "CorsProcessor must not be null"); |
||||
this.configSource = configSource; |
||||
this.processor = processor; |
||||
} |
||||
|
||||
|
||||
@Override |
||||
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) { |
||||
ServerHttpRequest request = exchange.getRequest(); |
||||
if (CorsUtils.isCorsRequest(request)) { |
||||
CorsConfiguration corsConfiguration = this.configSource.getCorsConfiguration(exchange); |
||||
if (corsConfiguration != null) { |
||||
boolean isValid = this.processor.process(corsConfiguration, exchange); |
||||
if (!isValid || CorsUtils.isPreFlightRequest(request)) { |
||||
return Mono.empty(); |
||||
} |
||||
} |
||||
} |
||||
return chain.filter(exchange); |
||||
} |
||||
|
||||
} |
@ -0,0 +1,124 @@
@@ -0,0 +1,124 @@
|
||||
package org.springframework.web.cors.reactive; |
||||
|
||||
|
||||
import java.io.IOException; |
||||
import java.util.Arrays; |
||||
|
||||
import javax.servlet.ServletException; |
||||
|
||||
import org.junit.Before; |
||||
import org.junit.Test; |
||||
import reactor.core.publisher.Mono; |
||||
|
||||
import org.springframework.http.HttpMethod; |
||||
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; |
||||
import org.springframework.mock.http.server.reactive.test.MockServerWebExchange; |
||||
import org.springframework.web.cors.CorsConfiguration; |
||||
import org.springframework.web.server.WebFilterChain; |
||||
|
||||
import static org.junit.Assert.assertEquals; |
||||
import static org.junit.Assert.assertNull; |
||||
import static org.springframework.http.HttpHeaders.*; |
||||
|
||||
/** |
||||
* Unit tests for {@link CorsWebFilter}. |
||||
* @author Sebastien Deleuze |
||||
*/ |
||||
public class CorsWebFilterTests { |
||||
|
||||
private CorsWebFilter filter; |
||||
|
||||
private final CorsConfiguration config = new CorsConfiguration(); |
||||
|
||||
@Before |
||||
public void setup() throws Exception { |
||||
config.setAllowedOrigins(Arrays.asList("http://domain1.com", "http://domain2.com")); |
||||
config.setAllowedMethods(Arrays.asList("GET", "POST")); |
||||
config.setAllowedHeaders(Arrays.asList("header1", "header2")); |
||||
config.setExposedHeaders(Arrays.asList("header3", "header4")); |
||||
config.setMaxAge(123L); |
||||
config.setAllowCredentials(false); |
||||
filter = new CorsWebFilter(r -> config); |
||||
} |
||||
|
||||
@Test |
||||
public void validActualRequest() { |
||||
|
||||
MockServerHttpRequest request = MockServerHttpRequest |
||||
.get("http://domain1.com/test.html") |
||||
.header(HOST, "domain1.com") |
||||
.header(ORIGIN, "http://domain2.com") |
||||
.header("header2", "foo") |
||||
.build(); |
||||
MockServerWebExchange exchange = new MockServerWebExchange(request); |
||||
|
||||
WebFilterChain filterChain = (filterExchange) -> { |
||||
try { |
||||
assertEquals("http://domain2.com", filterExchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); |
||||
assertEquals("header3, header4", filterExchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_EXPOSE_HEADERS)); |
||||
} catch (AssertionError ex) { |
||||
return Mono.error(ex); |
||||
} |
||||
return Mono.empty(); |
||||
|
||||
}; |
||||
filter.filter(exchange, filterChain); |
||||
} |
||||
|
||||
@Test |
||||
public void invalidActualRequest() throws ServletException, IOException { |
||||
|
||||
MockServerHttpRequest request = MockServerHttpRequest |
||||
.delete("http://domain1.com/test.html") |
||||
.header(HOST, "domain1.com") |
||||
.header(ORIGIN, "http://domain2.com") |
||||
.header("header2", "foo") |
||||
.build(); |
||||
MockServerWebExchange exchange = new MockServerWebExchange(request); |
||||
|
||||
WebFilterChain filterChain = (filterExchange) -> Mono.error(new AssertionError("Invalid requests must not be forwarded to the filter chain")); |
||||
filter.filter(exchange, filterChain); |
||||
|
||||
assertNull(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); |
||||
} |
||||
|
||||
@Test |
||||
public void validPreFlightRequest() throws ServletException, IOException { |
||||
|
||||
MockServerHttpRequest request = MockServerHttpRequest |
||||
.options("http://domain1.com/test.html") |
||||
.header(HOST, "domain1.com") |
||||
.header(ORIGIN, "http://domain2.com") |
||||
.header(ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.GET.name()) |
||||
.header(ACCESS_CONTROL_REQUEST_HEADERS, "header1, header2") |
||||
.build(); |
||||
MockServerWebExchange exchange = new MockServerWebExchange(request); |
||||
|
||||
WebFilterChain filterChain = (filterExchange) -> Mono.error(new AssertionError("Preflight requests must not be forwarded to the filter chain")); |
||||
filter.filter(exchange, filterChain); |
||||
|
||||
assertEquals("http://domain2.com", exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); |
||||
assertEquals("header1, header2", exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_HEADERS)); |
||||
assertEquals("header3, header4", exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_EXPOSE_HEADERS)); |
||||
assertEquals(123L, Long.parseLong(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_MAX_AGE))); |
||||
} |
||||
|
||||
@Test |
||||
public void invalidPreFlightRequest() throws ServletException, IOException { |
||||
|
||||
MockServerHttpRequest request = MockServerHttpRequest |
||||
.options("http://domain1.com/test.html") |
||||
.header(HOST, "domain1.com") |
||||
.header(ORIGIN, "http://domain2.com") |
||||
.header(ACCESS_CONTROL_REQUEST_METHOD, HttpMethod.DELETE.name()) |
||||
.header(ACCESS_CONTROL_REQUEST_HEADERS, "header1, header2") |
||||
.build(); |
||||
MockServerWebExchange exchange = new MockServerWebExchange(request); |
||||
|
||||
WebFilterChain filterChain = (filterExchange) -> Mono.error(new AssertionError("Preflight requests must not be forwarded to the filter chain")); |
||||
filter.filter(exchange, filterChain); |
||||
|
||||
assertNull(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); |
||||
} |
||||
|
||||
} |
Loading…
Reference in new issue