diff --git a/spring-cloud-gateway-webflux/src/main/java/org/springframework/cloud/gateway/webflux/ProxyExchange.java b/spring-cloud-gateway-webflux/src/main/java/org/springframework/cloud/gateway/webflux/ProxyExchange.java index c5c06364b..a115c9d31 100644 --- a/spring-cloud-gateway-webflux/src/main/java/org/springframework/cloud/gateway/webflux/ProxyExchange.java +++ b/spring-cloud-gateway-webflux/src/main/java/org/springframework/cloud/gateway/webflux/ProxyExchange.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.HashSet; import java.util.Set; import java.util.function.Function; +import java.util.stream.Collectors; import org.reactivestreams.Publisher; @@ -135,6 +136,7 @@ public class ProxyExchange { this.bindingContext = bindingContext; this.responseType = type; this.rest = rest; + this.sensitive = DEFAULT_SENSITIVE; } /** @@ -319,7 +321,7 @@ public class ProxyExchange { Type type = this.responseType; RequestBodySpec builder = rest.method(requestEntity.getMethod()) .uri(requestEntity.getUrl()) - .headers(headers -> headers.addAll(requestEntity.getHeaders())); + .headers(headers -> addHeaders(headers, requestEntity.getHeaders())); Mono result; if (requestEntity.getBody() instanceof Publisher) { @SuppressWarnings("unchecked") @@ -333,29 +335,31 @@ public class ProxyExchange { else { if (hasBody) { result = builder.headers( - headers -> headers.addAll(exchange.getRequest().getHeaders())) + headers -> addHeaders(headers, exchange.getRequest().getHeaders())) .body(exchange.getRequest().getBody(), DataBuffer.class) .exchange(); } else { result = builder.headers( - headers -> headers.addAll(exchange.getRequest().getHeaders())) + headers -> addHeaders(headers, exchange.getRequest().getHeaders())) .exchange(); } } return result.flatMap(response -> response.toEntity(ParameterizedTypeReference.forType(type))); } + private void addHeaders(HttpHeaders headers, HttpHeaders toAdd) { + Set filteredHeaders = filterHeaderKeys(toAdd); + filteredHeaders.stream().forEach(header -> headers.addAll(header, toAdd.get(header))); + } + + private Set filterHeaderKeys(HttpHeaders headers) { + return headers.keySet().stream().filter(header -> !sensitive.contains(header.toLowerCase())).collect(Collectors.toSet()); + } + private BodyBuilder headers(BodyBuilder builder) { - Set sensitive = this.sensitive; - if (sensitive == null) { - sensitive = DEFAULT_SENSITIVE; - } proxy(); - for (String name : headers.keySet()) { - if (sensitive.contains(name.toLowerCase())) { - continue; - } + for (String name : filterHeaderKeys(headers)) { builder.header(name, headers.get(name).toArray(new String[0])); } return builder; diff --git a/spring-cloud-gateway-webflux/src/test/java/org/springframework/cloud/gateway/webflux/ProductionConfigurationTests.java b/spring-cloud-gateway-webflux/src/test/java/org/springframework/cloud/gateway/webflux/ProductionConfigurationTests.java index 9d1040a8f..f250e9939 100644 --- a/spring-cloud-gateway-webflux/src/test/java/org/springframework/cloud/gateway/webflux/ProductionConfigurationTests.java +++ b/spring-cloud-gateway-webflux/src/test/java/org/springframework/cloud/gateway/webflux/ProductionConfigurationTests.java @@ -55,6 +55,8 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -175,6 +177,16 @@ public class ProductionConfigurationTests { .isEqualTo("host=localhost;foobar"); } + @Test + public void headers() throws Exception { + Map> headers = rest.exchange(RequestEntity.get(rest.getRestTemplate().getUriTemplateHandler() + .expand("/proxy/headers")).header("foo", "bar").header("abc", "xyz").build(), Map.class).getBody(); + assertTrue(!headers.containsKey("foo")); + assertTrue(!headers.containsKey("hello")); + assertEquals("hello", headers.get("bar")); + assertEquals("123", headers.get("abc")); + } + @SpringBootApplication static class TestApplication { @@ -269,6 +281,17 @@ public class ProductionConfigurationTests { .body(response.getBody().iterator().next())); } + @GetMapping("/proxy/headers") + public Mono>> headers(ProxyExchange> proxy) { + proxy.sensitive("foo"); + proxy.sensitive("hello"); + proxy.header("bar", "hello"); + proxy.header("abc", "123"); + proxy.header("hello", "world"); + return proxy.uri(home.toString() + "/headers").get(); + + } + private ResponseEntity first(ResponseEntity> response) { return ResponseEntity.status(response.getStatusCode()) .headers(response.getHeaders()) @@ -308,6 +331,11 @@ public class ProductionConfigurationTests { return Arrays.asList(new Bar(custom + foos.iterator().next().getName())); } + @GetMapping("/headers") + public Map headers(@RequestHeader HttpHeaders headers) { + return headers.toSingleValueMap(); + } + } @JsonIgnoreProperties(ignoreUnknown = true)