Browse Source
To do so, adds 2 components: * ApplicationLister that updates CorsProperties based on route metadata Also: * Renamed CorsTests to CorsGlobalTests * Add test CorsPerRouteTests * Add docs: split current single section into 2: global & route configpull/2757/head
8 changed files with 313 additions and 4 deletions
@ -0,0 +1,133 @@
@@ -0,0 +1,133 @@
|
||||
/* |
||||
* Copyright 2013-2020 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package org.springframework.cloud.gateway.filter.cors; |
||||
|
||||
import java.util.ArrayList; |
||||
import java.util.Arrays; |
||||
import java.util.HashMap; |
||||
import java.util.List; |
||||
import java.util.Map; |
||||
import java.util.Optional; |
||||
|
||||
import org.springframework.cloud.gateway.config.GlobalCorsProperties; |
||||
import org.springframework.cloud.gateway.event.RefreshRoutesEvent; |
||||
import org.springframework.cloud.gateway.handler.RoutePredicateHandlerMapping; |
||||
import org.springframework.cloud.gateway.route.RouteDefinition; |
||||
import org.springframework.cloud.gateway.route.RouteDefinitionLocator; |
||||
import org.springframework.context.ApplicationListener; |
||||
import org.springframework.web.cors.CorsConfiguration; |
||||
|
||||
/** |
||||
* @author Fredrich Ombico |
||||
* @author Abel Salgado Romero |
||||
*/ |
||||
public class CorsGatewayFilterApplicationListener implements ApplicationListener<RefreshRoutesEvent> { |
||||
|
||||
private final GlobalCorsProperties globalCorsProperties; |
||||
|
||||
private final RoutePredicateHandlerMapping routePredicateHandlerMapping; |
||||
|
||||
private final RouteDefinitionLocator routeDefinitionLocator; |
||||
|
||||
private static final String PATH_PREDICATE_NAME = "Path"; |
||||
|
||||
private static final String METADATA_KEY = "cors"; |
||||
|
||||
private static final String ALL_PATHS = "/**"; |
||||
|
||||
public CorsGatewayFilterApplicationListener(GlobalCorsProperties globalCorsProperties, |
||||
RoutePredicateHandlerMapping routePredicateHandlerMapping, RouteDefinitionLocator routeDefinitionLocator) { |
||||
this.globalCorsProperties = globalCorsProperties; |
||||
this.routePredicateHandlerMapping = routePredicateHandlerMapping; |
||||
this.routeDefinitionLocator = routeDefinitionLocator; |
||||
} |
||||
|
||||
@Override |
||||
public void onApplicationEvent(RefreshRoutesEvent event) { |
||||
routeDefinitionLocator.getRouteDefinitions().collectList().subscribe(routeDefinitions -> { |
||||
// pre-populate with pre-existing global cors configurations to combine with.
|
||||
var corsConfigurations = new HashMap<>(globalCorsProperties.getCorsConfigurations()); |
||||
|
||||
routeDefinitions.forEach(routeDefinition -> { |
||||
var pathPredicate = getPathPredicate(routeDefinition); |
||||
var corsConfiguration = getCorsConfiguration(routeDefinition); |
||||
corsConfiguration.ifPresent(configuration -> corsConfigurations.put(pathPredicate, configuration)); |
||||
}); |
||||
|
||||
routePredicateHandlerMapping.setCorsConfigurations(corsConfigurations); |
||||
}); |
||||
} |
||||
|
||||
private String getPathPredicate(RouteDefinition routeDefinition) { |
||||
return routeDefinition.getPredicates().stream() |
||||
.filter(predicate -> PATH_PREDICATE_NAME.equals(predicate.getName())).findFirst() |
||||
.flatMap(predicate -> predicate.getArgs().values().stream().findFirst()).orElse(ALL_PATHS); |
||||
} |
||||
|
||||
private Optional<CorsConfiguration> getCorsConfiguration(RouteDefinition routeDefinition) { |
||||
Map<String, Object> corsMetadata = (Map<String, Object>) routeDefinition.getMetadata().get(METADATA_KEY); |
||||
if (corsMetadata != null) { |
||||
final CorsConfiguration corsConfiguration = new CorsConfiguration(); |
||||
|
||||
findValue(corsMetadata, "allowCredential") |
||||
.ifPresent(value -> corsConfiguration.setAllowCredentials((Boolean) value)); |
||||
findValue(corsMetadata, "allowedHeaders") |
||||
.ifPresent(value -> corsConfiguration.setAllowedHeaders(asList(value))); |
||||
findValue(corsMetadata, "allowedMethods") |
||||
.ifPresent(value -> corsConfiguration.setAllowedMethods(asList(value))); |
||||
findValue(corsMetadata, "allowedOriginPatterns") |
||||
.ifPresent(value -> corsConfiguration.setAllowedOriginPatterns(asList(value))); |
||||
findValue(corsMetadata, "allowedOrigins") |
||||
.ifPresent(value -> corsConfiguration.setAllowedOrigins(asList(value))); |
||||
findValue(corsMetadata, "exposedHeaders") |
||||
.ifPresent(value -> corsConfiguration.setExposedHeaders(asList(value))); |
||||
findValue(corsMetadata, "maxAge") |
||||
.ifPresent(value -> corsConfiguration.setMaxAge(asLong(value))); |
||||
|
||||
return Optional.of(corsConfiguration); |
||||
} |
||||
|
||||
return Optional.empty(); |
||||
} |
||||
|
||||
private Optional<Object> findValue(Map<String, Object> metadata, String key) { |
||||
Object value = metadata.get(key); |
||||
return Optional.ofNullable(value); |
||||
} |
||||
|
||||
private List<String> asList(Object value) { |
||||
if (value instanceof String) { |
||||
return Arrays.asList((String) value); |
||||
} |
||||
if (value instanceof Map) { |
||||
return new ArrayList<>(((Map<?, String>) value).values()); |
||||
} |
||||
else { |
||||
return (List<String>) value; |
||||
} |
||||
} |
||||
|
||||
private Long asLong(Object value) { |
||||
if (value instanceof Integer) { |
||||
return ((Integer) value).longValue(); |
||||
} |
||||
else { |
||||
return (Long) value; |
||||
} |
||||
} |
||||
|
||||
} |
@ -0,0 +1,91 @@
@@ -0,0 +1,91 @@
|
||||
/* |
||||
* Copyright 2013-2020 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package org.springframework.cloud.gateway.cors; |
||||
|
||||
import java.util.Map; |
||||
|
||||
import org.junit.jupiter.api.Test; |
||||
|
||||
import org.springframework.boot.SpringBootConfiguration; |
||||
import org.springframework.boot.autoconfigure.EnableAutoConfiguration; |
||||
import org.springframework.boot.test.context.SpringBootTest; |
||||
import org.springframework.cloud.gateway.test.BaseWebClientTests; |
||||
import org.springframework.context.annotation.Import; |
||||
import org.springframework.http.HttpHeaders; |
||||
import org.springframework.http.HttpMethod; |
||||
import org.springframework.http.HttpStatus; |
||||
import org.springframework.test.annotation.DirtiesContext; |
||||
import org.springframework.test.context.ActiveProfiles; |
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat; |
||||
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT; |
||||
import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN; |
||||
import static org.springframework.http.HttpHeaders.ACCESS_CONTROL_MAX_AGE; |
||||
|
||||
@SpringBootTest(webEnvironment = RANDOM_PORT) |
||||
@DirtiesContext |
||||
@ActiveProfiles(profiles = "cors-per-route-config") |
||||
public class CorsPerRouteTests extends BaseWebClientTests { |
||||
|
||||
@Test |
||||
public void testPreFlightCorsRequest() { |
||||
testClient.options().uri("/abc").header("Origin", "domain.com").header("Access-Control-Request-Method", "GET") |
||||
.exchange().expectBody(Map.class).consumeWith(result -> { |
||||
assertThat(result.getResponseBody()).isNull(); |
||||
assertThat(result.getStatus()).isEqualTo(HttpStatus.OK); |
||||
|
||||
HttpHeaders responseHeaders = result.getResponseHeaders(); |
||||
assertThat(responseHeaders.getAccessControlAllowOrigin()) |
||||
.as(missingHeader(ACCESS_CONTROL_ALLOW_ORIGIN)).isEqualTo("*"); |
||||
assertThat(responseHeaders.getAccessControlAllowMethods()) |
||||
.as(missingHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)) |
||||
.containsExactlyInAnyOrder(HttpMethod.GET, HttpMethod.POST); |
||||
assertThat(responseHeaders.getAccessControlMaxAge()).as(missingHeader(ACCESS_CONTROL_MAX_AGE)) |
||||
.isEqualTo(30L); |
||||
}); |
||||
} |
||||
|
||||
@Test |
||||
public void testPreFlightForbiddenCorsRequest() { |
||||
testClient.get().uri("/cors").header("Origin", "domain.com").header("Access-Control-Request-Method", "GET") |
||||
.exchange().expectBody(Map.class).consumeWith(result -> { |
||||
assertThat(result.getResponseBody()).isNull(); |
||||
assertThat(result.getStatus()).isEqualTo(HttpStatus.FORBIDDEN); |
||||
}); |
||||
} |
||||
|
||||
@Test |
||||
public void testCorsValidatedRequest() { |
||||
testClient.get().uri("/cors/status/201").header("Origin", "https://test.com").exchange() |
||||
.expectBody(String.class).consumeWith(result -> { |
||||
assertThat(result.getResponseBody()).endsWith("201"); |
||||
assertThat(result.getStatus()).isEqualTo(HttpStatus.CREATED); |
||||
}); |
||||
} |
||||
|
||||
private String missingHeader(String accessControlAllowOrigin) { |
||||
return "Missing header value in response: " + accessControlAllowOrigin; |
||||
} |
||||
|
||||
@EnableAutoConfiguration |
||||
@SpringBootConfiguration |
||||
@Import(DefaultTestConfig.class) |
||||
public static class TestConfig { |
||||
|
||||
} |
||||
|
||||
} |
@ -0,0 +1,10 @@
@@ -0,0 +1,10 @@
|
||||
spring: |
||||
cloud: |
||||
gateway: |
||||
globalcors: |
||||
cors-configurations: |
||||
'[/**]': |
||||
maxAge: 10 |
||||
allowedOrigins: "*" |
||||
allowedMethods: |
||||
- GET |
@ -0,0 +1,27 @@
@@ -0,0 +1,27 @@
|
||||
spring: |
||||
cloud: |
||||
gateway: |
||||
routes: |
||||
- id: cors_preflight_test |
||||
uri: ${test.uri} |
||||
predicates: |
||||
- Path=/abc/** |
||||
metadata: |
||||
cors: |
||||
allowedOrigins: '*' |
||||
allowedMethods: [ GET, POST ] |
||||
allowedHeaders: '*' |
||||
maxAge: 30 |
||||
- id: cors_test |
||||
uri: ${test.uri} |
||||
predicates: |
||||
- Path=/cors/** |
||||
filters: |
||||
- StripPrefix=1 |
||||
metadata: |
||||
cors: |
||||
allowedOrigins: https://test.com |
||||
allowedMethods: |
||||
- GET |
||||
- PUT |
||||
allowedHeaders: '*' |
Loading…
Reference in new issue