Browse Source

Adds BeforeFilterFunctions.addRequestHeadersIfNotPresent()

See gh-2949
pull/3006/head
sgibb 1 year ago
parent
commit
cf39b20ae2
No known key found for this signature in database
GPG Key ID: 7788A47380690861
  1. 79
      spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/common/KeyValues.java
  2. 5
      spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/common/MvcUtils.java
  3. 26
      spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/BeforeFilterFunctions.java
  4. 10
      spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/FilterFunctions.java
  5. 27
      spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/ServerMvcIntegrationTests.java

79
spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/common/KeyValues.java

@ -0,0 +1,79 @@
/*
* Copyright 2013-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.gateway.server.mvc.common;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.springframework.core.style.ToStringCreator;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
public class KeyValues {
private List<KeyValue> keyValues = new ArrayList<>();
public List<KeyValue> getKeyValues() {
return keyValues;
}
public void setKeyValues(List<KeyValue> keyValues) {
this.keyValues = keyValues;
}
public static KeyValues valueOf(String s) {
String[] tokens = StringUtils.tokenizeToStringArray(s, ",", true, true);
List<KeyValue> parsedKeyValues = Arrays.stream(tokens).map(KeyValue::valueOf).toList();
KeyValues keyValues = new KeyValues();
keyValues.setKeyValues(parsedKeyValues);
return keyValues;
}
public static class KeyValue {
private final String key;
private final String value;
public KeyValue(String key, String value) {
this.key = key;
this.value = value;
}
public String getKey() {
return key;
}
public String getValue() {
return value;
}
@Override
public String toString() {
return new ToStringCreator(this).append("name", key).append("value", value).toString();
}
public static KeyValue valueOf(String s) {
String[] tokens = StringUtils.tokenizeToStringArray(s, ":", true, true);
Assert.isTrue(tokens.length == 2, () -> "String must be two tokens delimited by colon, but was " + s);
return new KeyValue(tokens[0], tokens[1]);
}
}
}

5
spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/common/MvcUtils.java

@ -18,6 +18,7 @@ package org.springframework.cloud.gateway.server.mvc.common;
import java.net.URI; import java.net.URI;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -67,6 +68,10 @@ public abstract class MvcUtils {
return UriComponentsBuilder.fromPath(template).build().expand(variables).getPath(); return UriComponentsBuilder.fromPath(template).build().expand(variables).getPath();
} }
public static List<String> expandMultiple(ServerRequest request, Collection<String> templates) {
return templates.stream().map(value -> MvcUtils.expand(request, value)).toList();
}
public static String[] expandMultiple(ServerRequest request, String... templates) { public static String[] expandMultiple(ServerRequest request, String... templates) {
List<String> expanded = Arrays.stream(templates).map(value -> MvcUtils.expand(request, value)).toList(); List<String> expanded = Arrays.stream(templates).map(value -> MvcUtils.expand(request, value)).toList();
return expanded.toArray(new String[0]); return expanded.toArray(new String[0]);

26
spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/BeforeFilterFunctions.java

@ -18,6 +18,7 @@ package org.springframework.cloud.gateway.server.mvc.filter;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -27,6 +28,7 @@ import java.util.regex.Pattern;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.gateway.server.mvc.common.KeyValues.KeyValue;
import org.springframework.cloud.gateway.server.mvc.common.MvcUtils; import org.springframework.cloud.gateway.server.mvc.common.MvcUtils;
import org.springframework.core.log.LogMessage; import org.springframework.core.log.LogMessage;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
@ -64,6 +66,30 @@ public abstract class BeforeFilterFunctions {
}; };
} }
public static Function<ServerRequest, ServerRequest> addRequestHeadersIfNotPresent(String... values) {
List<KeyValue> keyValues = Arrays.stream(values).map(KeyValue::valueOf).toList();
return addRequestHeadersIfNotPresent(keyValues);
}
public static Function<ServerRequest, ServerRequest> addRequestHeadersIfNotPresent(List<KeyValue> keyValues) {
HttpHeaders newHeaders = new HttpHeaders();
keyValues.forEach(keyValue -> newHeaders.add(keyValue.getKey(), keyValue.getValue()));
return request -> {
ServerRequest.Builder requestBuilder = ServerRequest.from(request);
newHeaders.forEach((newHeaderName, newHeaderValues) -> {
boolean headerIsMissingOrBlank = request.headers().asHttpHeaders().getOrEmpty(newHeaderName).stream()
.allMatch(h -> !StringUtils.hasText(h));
if (headerIsMissingOrBlank) {
requestBuilder.headers(httpHeaders -> {
List<String> expandedValues = MvcUtils.expandMultiple(request, newHeaderValues);
httpHeaders.addAll(newHeaderName, expandedValues);
});
}
});
return requestBuilder.build();
};
}
public static Function<ServerRequest, ServerRequest> addRequestParameter(String name, String... values) { public static Function<ServerRequest, ServerRequest> addRequestParameter(String name, String... values) {
return request -> { return request -> {
String[] expandedValues = MvcUtils.expandMultiple(request, values); String[] expandedValues = MvcUtils.expandMultiple(request, values);

10
spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/FilterFunctions.java

@ -22,6 +22,7 @@ import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import org.springframework.cloud.gateway.server.mvc.common.HttpStatusHolder; import org.springframework.cloud.gateway.server.mvc.common.HttpStatusHolder;
import org.springframework.cloud.gateway.server.mvc.common.KeyValues;
import org.springframework.cloud.gateway.server.mvc.common.Shortcut; import org.springframework.cloud.gateway.server.mvc.common.Shortcut;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatusCode; import org.springframework.http.HttpStatusCode;
@ -40,6 +41,15 @@ public interface FilterFunctions {
return ofRequestProcessor(BeforeFilterFunctions.addRequestHeader(name, values)); return ofRequestProcessor(BeforeFilterFunctions.addRequestHeader(name, values));
} }
static HandlerFilterFunction<ServerResponse, ServerResponse> addRequestHeadersIfNotPresent(String values) {
return addRequestHeadersIfNotPresent(KeyValues.valueOf(values));
}
@Shortcut
static HandlerFilterFunction<ServerResponse, ServerResponse> addRequestHeadersIfNotPresent(KeyValues keyValues) {
return ofRequestProcessor(BeforeFilterFunctions.addRequestHeadersIfNotPresent(keyValues.getKeyValues()));
}
@Shortcut @Shortcut
static HandlerFilterFunction<ServerResponse, ServerResponse> addRequestParameter(String name, String... values) { static HandlerFilterFunction<ServerResponse, ServerResponse> addRequestParameter(String name, String... values) {
return ofRequestProcessor(BeforeFilterFunctions.addRequestParameter(name, values)); return ofRequestProcessor(BeforeFilterFunctions.addRequestParameter(name, values));

27
spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/ServerMvcIntegrationTests.java

@ -84,6 +84,7 @@ import static org.springframework.cloud.gateway.server.mvc.filter.BeforeFilterFu
import static org.springframework.cloud.gateway.server.mvc.filter.Bucket4jFilterFunctions.rateLimit; import static org.springframework.cloud.gateway.server.mvc.filter.Bucket4jFilterFunctions.rateLimit;
import static org.springframework.cloud.gateway.server.mvc.filter.CircuitBreakerFilterFunctions.circuitBreaker; import static org.springframework.cloud.gateway.server.mvc.filter.CircuitBreakerFilterFunctions.circuitBreaker;
import static org.springframework.cloud.gateway.server.mvc.filter.FilterFunctions.addRequestHeader; import static org.springframework.cloud.gateway.server.mvc.filter.FilterFunctions.addRequestHeader;
import static org.springframework.cloud.gateway.server.mvc.filter.FilterFunctions.addRequestHeadersIfNotPresent;
import static org.springframework.cloud.gateway.server.mvc.filter.FilterFunctions.addRequestParameter; import static org.springframework.cloud.gateway.server.mvc.filter.FilterFunctions.addRequestParameter;
import static org.springframework.cloud.gateway.server.mvc.filter.FilterFunctions.prefixPath; import static org.springframework.cloud.gateway.server.mvc.filter.FilterFunctions.prefixPath;
import static org.springframework.cloud.gateway.server.mvc.filter.FilterFunctions.redirectTo; import static org.springframework.cloud.gateway.server.mvc.filter.FilterFunctions.redirectTo;
@ -475,6 +476,20 @@ public class ServerMvcIntegrationTests {
}); });
} }
@Test
public void addRequestHeadersIfNotPresentWorks() {
restClient.get().uri("/headers").header("Host", "www.addrequestheadersifnotpresent.org")
.header("X-Request-Beta", "Value1").exchange().expectStatus().isOk().expectBody(Map.class)
.consumeWith(res -> {
Map<String, Object> headers = getMap(res.getResponseBody(), "headers");
// this asserts that Value2 was not added
assertThat(headers).containsEntry("X-Request-Beta", "Value1");
assertThat(headers).containsKey("X-Request-Acme");
List<String> values = (List<String>) headers.get("X-Request-Acme");
assertThat(values).hasSize(4).containsOnly("ValueX", "ValueY", "ValueZ", "www");
});
}
@SpringBootConfiguration @SpringBootConfiguration
@EnableAutoConfiguration @EnableAutoConfiguration
@LoadBalancerClient(name = "httpbin", configuration = TestLoadBalancerConfig.Httpbin.class) @LoadBalancerClient(name = "httpbin", configuration = TestLoadBalancerConfig.Httpbin.class)
@ -854,6 +869,18 @@ public class ServerMvcIntegrationTests {
// @formatter:on // @formatter:on
} }
@Bean
public RouterFunction<ServerResponse> gatewayRouterFunctionsAddRequestHeadersIfNotPresent() {
// @formatter:off
return route("testaddrequestheadersifnotpresent")
.route(GET("/headers").and(host("{sub}.addrequestheadersifnotpresent.org")), http())
.filter(new HttpbinUriResolver())
// normally use BeforeFilterFunctions version, but wanted to test parsing for config
.filter(addRequestHeadersIfNotPresent("X-Request-Acme:ValueX, X-Request-Acme:ValueY,X-Request-Acme:ValueZ, X-Request-Acme:{sub},X-Request-Beta:Value2"))
.build();
// @formatter:on
}
} }
@RestController @RestController

Loading…
Cancel
Save