From 0eb7206ab949cbf0e619593a3b16616818f680ae Mon Sep 17 00:00:00 2001 From: Spencer Gibb Date: Wed, 14 Nov 2018 18:41:46 -0500 Subject: [PATCH] Updates webflux ProxyExchange to overwrite incomming headers rather than add to them. Added the test to mvc ProxyExchange and it currently exhibits this behavior already and the test passes. see gh-643 --- .../mvc/ProductionConfigurationTests.java | 32 ++++++++++++++++++- .../cloud/gateway/webflux/ProxyExchange.java | 6 ++-- .../webflux/ProductionConfigurationTests.java | 18 ++++++----- 3 files changed, 45 insertions(+), 11 deletions(-) diff --git a/spring-cloud-gateway-mvc/src/test/java/org/springframework/cloud/gateway/mvc/ProductionConfigurationTests.java b/spring-cloud-gateway-mvc/src/test/java/org/springframework/cloud/gateway/mvc/ProductionConfigurationTests.java index e3054ecc1..0ab227c5c 100644 --- a/spring-cloud-gateway-mvc/src/test/java/org/springframework/cloud/gateway/mvc/ProductionConfigurationTests.java +++ b/spring-cloud-gateway-mvc/src/test/java/org/springframework/cloud/gateway/mvc/ProductionConfigurationTests.java @@ -44,6 +44,7 @@ import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringRunner; +import org.springframework.util.LinkedMultiValueMap; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; @@ -230,7 +231,21 @@ public class ProductionConfigurationTests { .isEqualTo("host=localhost;foobar"); } - @SpringBootApplication + @Test + @SuppressWarnings({"Duplicates", "unchecked"}) + public void headers() throws Exception { + Map> headers = rest.exchange(RequestEntity.get(rest.getRestTemplate().getUriTemplateHandler() + .expand("/proxy/headers")).header("foo", "bar").header("abc", "xyz").build(), Map.class).getBody(); + System.out.println(headers); + assertThat(headers).doesNotContainKey("foo") + .doesNotContainKey("hello") + .containsKeys("bar", "abc"); + + assertThat(headers.get("bar")).containsOnly("hello"); + assertThat(headers.get("abc")).containsOnly("123"); + } + + @SpringBootApplication static class TestApplication { @RestController @@ -359,6 +374,17 @@ public class ProductionConfigurationTests { proxy.forward(path); } + @GetMapping("/proxy/headers") + @SuppressWarnings("Duplicates") + public ResponseEntity>> headers(ProxyExchange>> proxy) { + proxy.sensitive("foo"); + proxy.sensitive("hello"); + proxy.header("bar", "hello"); + proxy.header("abc", "123"); + proxy.header("hello", "world"); + return proxy.uri(home.toString() + "/headers").get(); + } + private ResponseEntity first(ResponseEntity> response) { return ResponseEntity.status(response.getStatusCode()) .headers(response.getHeaders()) @@ -397,6 +423,10 @@ public class ProductionConfigurationTests { return Arrays.asList(new Bar(custom + foos.iterator().next().getName())); } + @GetMapping("/headers") + public Map> headers(@RequestHeader HttpHeaders headers) { + return new LinkedMultiValueMap<>(headers); + } } @JsonIgnoreProperties(ignoreUnknown = true) diff --git a/spring-cloud-gateway-webflux/src/main/java/org/springframework/cloud/gateway/webflux/ProxyExchange.java b/spring-cloud-gateway-webflux/src/main/java/org/springframework/cloud/gateway/webflux/ProxyExchange.java index 08f44228e..a59812ca6 100644 --- a/spring-cloud-gateway-webflux/src/main/java/org/springframework/cloud/gateway/webflux/ProxyExchange.java +++ b/spring-cloud-gateway-webflux/src/main/java/org/springframework/cloud/gateway/webflux/ProxyExchange.java @@ -350,8 +350,10 @@ public class ProxyExchange { } private void addHeaders(HttpHeaders headers, HttpHeaders toAdd) { - Set filteredHeaders = filterHeaderKeys(toAdd); - filteredHeaders.stream().forEach(header -> headers.addAll(header, toAdd.get(header))); + Set filteredKeys = filterHeaderKeys(toAdd); + filteredKeys.stream() + .filter(key -> !headers.containsKey(key)) + .forEach(header -> headers.addAll(header, toAdd.get(header))); } private Set filterHeaderKeys(HttpHeaders headers) { diff --git a/spring-cloud-gateway-webflux/src/test/java/org/springframework/cloud/gateway/webflux/ProductionConfigurationTests.java b/spring-cloud-gateway-webflux/src/test/java/org/springframework/cloud/gateway/webflux/ProductionConfigurationTests.java index f250e9939..f2cf0746c 100644 --- a/spring-cloud-gateway-webflux/src/test/java/org/springframework/cloud/gateway/webflux/ProductionConfigurationTests.java +++ b/spring-cloud-gateway-webflux/src/test/java/org/springframework/cloud/gateway/webflux/ProductionConfigurationTests.java @@ -178,13 +178,16 @@ public class ProductionConfigurationTests { } @Test + @SuppressWarnings({"Duplicates", "unchecked"}) public void headers() throws Exception { Map> headers = rest.exchange(RequestEntity.get(rest.getRestTemplate().getUriTemplateHandler() .expand("/proxy/headers")).header("foo", "bar").header("abc", "xyz").build(), Map.class).getBody(); - assertTrue(!headers.containsKey("foo")); - assertTrue(!headers.containsKey("hello")); - assertEquals("hello", headers.get("bar")); - assertEquals("123", headers.get("abc")); + assertThat(headers).doesNotContainKey("foo") + .doesNotContainKey("hello") + .containsKeys("bar", "abc"); + + assertThat(headers.get("bar")).containsOnly("hello"); + assertThat(headers.get("abc")).containsOnly("123"); } @SpringBootApplication @@ -282,14 +285,13 @@ public class ProductionConfigurationTests { } @GetMapping("/proxy/headers") - public Mono>> headers(ProxyExchange> proxy) { + public Mono>>> headers(ProxyExchange>> proxy) { proxy.sensitive("foo"); proxy.sensitive("hello"); proxy.header("bar", "hello"); proxy.header("abc", "123"); proxy.header("hello", "world"); return proxy.uri(home.toString() + "/headers").get(); - } private ResponseEntity first(ResponseEntity> response) { @@ -332,8 +334,8 @@ public class ProductionConfigurationTests { } @GetMapping("/headers") - public Map headers(@RequestHeader HttpHeaders headers) { - return headers.toSingleValueMap(); + public Map> headers(@RequestHeader HttpHeaders headers) { + return headers; } }