diff --git a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/NettyRoutingFilter.java b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/NettyRoutingFilter.java index b8749bdc3..cba9d9fda 100644 --- a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/NettyRoutingFilter.java +++ b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/NettyRoutingFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2017 the original author or authors. + * Copyright 2013-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,14 @@ package org.springframework.cloud.gateway.filter; import java.net.URI; import java.util.List; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter.Type; +import reactor.core.publisher.Mono; +import reactor.ipc.netty.NettyPipeline; +import reactor.ipc.netty.http.client.HttpClient; +import reactor.ipc.netty.http.client.HttpClientRequest; + import org.springframework.beans.factory.ObjectProvider; import org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter; import org.springframework.core.Ordered; @@ -30,21 +38,16 @@ import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.web.server.ServerWebExchange; +import static org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter.filterRequest; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CLIENT_RESPONSE_ATTR; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.PRESERVE_HOST_HEADER_ATTRIBUTE; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.isAlreadyRouted; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.setAlreadyRouted; -import io.netty.handler.codec.http.DefaultHttpHeaders; -import io.netty.handler.codec.http.HttpMethod; -import reactor.core.publisher.Mono; -import reactor.ipc.netty.NettyPipeline; -import reactor.ipc.netty.http.client.HttpClient; -import reactor.ipc.netty.http.client.HttpClientRequest; - /** * @author Spencer Gibb + * @author Biju Kunjummen */ public class NettyRoutingFilter implements GlobalFilter, Ordered { @@ -52,7 +55,7 @@ public class NettyRoutingFilter implements GlobalFilter, Ordered { private final ObjectProvider> headersFilters; public NettyRoutingFilter(HttpClient httpClient, - ObjectProvider> headersFilters) { + ObjectProvider> headersFilters) { this.httpClient = httpClient; this.headersFilters = headersFilters; } @@ -77,8 +80,8 @@ public class NettyRoutingFilter implements GlobalFilter, Ordered { final HttpMethod method = HttpMethod.valueOf(request.getMethod().toString()); final String url = requestUrl.toString(); - HttpHeaders filtered = HttpHeadersFilter.filter(this.headersFilters.getIfAvailable(), - request); + HttpHeaders filtered = filterRequest(this.headersFilters.getIfAvailable(), + exchange); final DefaultHttpHeaders httpHeaders = new DefaultHttpHeaders(); filtered.forEach(httpHeaders::set); @@ -107,9 +110,13 @@ public class NettyRoutingFilter implements GlobalFilter, Ordered { ServerHttpResponse response = exchange.getResponse(); // put headers and status so filters can modify the response HttpHeaders headers = new HttpHeaders(); + res.responseHeaders().forEach(entry -> headers.add(entry.getKey(), entry.getValue())); - response.getHeaders().putAll(headers); + HttpHeaders filteredResponseHeaders = HttpHeadersFilter.filter( + this.headersFilters.getIfAvailable(), headers, exchange, Type.RESPONSE); + + response.getHeaders().putAll(filteredResponseHeaders); response.setStatusCode(HttpStatus.valueOf(res.status().code())); // Defer committing the response until all route filters have run diff --git a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/WebsocketRoutingFilter.java b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/WebsocketRoutingFilter.java index 455f62d7a..e84f68257 100644 --- a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/WebsocketRoutingFilter.java +++ b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/WebsocketRoutingFilter.java @@ -20,6 +20,7 @@ import org.springframework.web.reactive.socket.client.WebSocketClient; import org.springframework.web.reactive.socket.server.WebSocketService; import org.springframework.web.server.ServerWebExchange; +import static org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter.filterRequest; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.isAlreadyRouted; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.setAlreadyRouted; @@ -60,8 +61,8 @@ public class WebsocketRoutingFilter implements GlobalFilter, Ordered { HttpHeaders headers = exchange.getRequest().getHeaders(); - HttpHeaders filtered = HttpHeadersFilter.filter(getHeadersFilters(), - exchange.getRequest()); + HttpHeaders filtered = filterRequest(getHeadersFilters(), + exchange); List protocols = headers.get(SEC_WEBSOCKET_PROTOCOL); if (protocols != null) { @@ -82,9 +83,9 @@ public class WebsocketRoutingFilter implements GlobalFilter, Ordered { filters = new ArrayList<>(); } - filters.add(request -> { + filters.add((headers, exchange) -> { HttpHeaders filtered = new HttpHeaders(); - request.getHeaders().entrySet().stream() + headers.entrySet().stream() .filter(entry -> !entry.getKey().toLowerCase().startsWith("sec-websocket")) .forEach(header -> filtered.addAll(header.getKey(), header.getValue())); return filtered; diff --git a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/ForwardedHeadersFilter.java b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/ForwardedHeadersFilter.java index 98053dc33..b3ffcf46e 100644 --- a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/ForwardedHeadersFilter.java +++ b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/ForwardedHeadersFilter.java @@ -33,6 +33,7 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; public class ForwardedHeadersFilter implements HttpHeadersFilter, Ordered { @@ -44,8 +45,9 @@ public class ForwardedHeadersFilter implements HttpHeadersFilter, Ordered { } @Override - public HttpHeaders filter(ServerHttpRequest request) { - HttpHeaders original = request.getHeaders(); + public HttpHeaders filter(HttpHeaders input, ServerWebExchange exchange) { + ServerHttpRequest request = exchange.getRequest(); + HttpHeaders original = input; HttpHeaders updated = new HttpHeaders(); // copy all headers except Forwarded diff --git a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/HttpHeadersFilter.java b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/HttpHeadersFilter.java index f5bc86f6a..d23015728 100644 --- a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/HttpHeadersFilter.java +++ b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/HttpHeadersFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2018 the original author or authors. + * Copyright 2017-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,27 +17,55 @@ package org.springframework.cloud.gateway.filter.headers; -import org.springframework.http.HttpHeaders; -import org.springframework.http.server.reactive.ServerHttpRequest; - import java.util.List; -@FunctionalInterface +import org.springframework.http.HttpHeaders; +import org.springframework.web.server.ServerWebExchange; + public interface HttpHeadersFilter { - HttpHeaders filter(ServerHttpRequest request); + enum Type { + REQUEST, RESPONSE + } + + /** + * Filters a set of Http Headers + * + * @param input Http Headers + * @param exchange + * @return filtered Http Headers + */ + HttpHeaders filter(HttpHeaders input, ServerWebExchange exchange); - static HttpHeaders filter(final List filters, final ServerHttpRequest request) { - ServerHttpRequest clonedReq = request.mutate().build(); + static HttpHeaders filterRequest(List filters, + ServerWebExchange exchange) { + HttpHeaders headers = exchange.getRequest().getHeaders(); + return filter(filters, headers, exchange, Type.REQUEST); + } + + static HttpHeaders filter(List filters, HttpHeaders input, + ServerWebExchange exchange, Type type) { + HttpHeaders response = input; if (filters != null) { - for (HttpHeadersFilter filter: filters) { - HttpHeaders filtered = filter.filter(clonedReq); - clonedReq = clonedReq.mutate().headers(httpHeaders -> { - httpHeaders.clear(); - httpHeaders.putAll(filtered); - }).build(); - } + HttpHeaders reduce = filters.stream() + .filter(headersFilter -> headersFilter.supports(type)) + .reduce(input, + (headers, filter) -> filter.filter(headers, exchange), + (httpHeaders, httpHeaders2) -> { + httpHeaders.addAll(httpHeaders2); + return httpHeaders; + }); + return reduce; + } + + return response; + } + + default boolean supports(Type type) { + if (type.equals(Type.REQUEST)) { + return true; } - return clonedReq.getHeaders(); + + return false; } } diff --git a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/RemoveHopByHopHeadersFilter.java b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/RemoveHopByHopHeadersFilter.java index 651426f56..5c74271e9 100644 --- a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/RemoveHopByHopHeadersFilter.java +++ b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/RemoveHopByHopHeadersFilter.java @@ -19,13 +19,12 @@ package org.springframework.cloud.gateway.filter.headers; import java.util.Arrays; import java.util.HashSet; -import java.util.List; import java.util.Set; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.core.Ordered; import org.springframework.http.HttpHeaders; -import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.web.server.ServerWebExchange; @ConfigurationProperties("spring.cloud.gateway.filter.remove-hop-by-hop") public class RemoveHopByHopHeadersFilter implements HttpHeadersFilter, Ordered { @@ -68,17 +67,19 @@ public class RemoveHopByHopHeadersFilter implements HttpHeadersFilter, Ordered { } @Override - public HttpHeaders filter(ServerHttpRequest request) { - HttpHeaders original = request.getHeaders(); + public HttpHeaders filter(HttpHeaders input, ServerWebExchange exchange) { HttpHeaders filtered = new HttpHeaders(); - List connection = original.getConnection(); - Set toFilter = new HashSet<>(connection); - toFilter.addAll(this.headers); - - original.entrySet().stream() - .filter(entry -> !toFilter.contains(entry.getKey().toLowerCase())) + + input.entrySet().stream() + .filter(entry -> !this.headers.contains(entry.getKey().toLowerCase())) .forEach(entry -> filtered.addAll(entry.getKey(), entry.getValue())); return filtered; } + + @Override + public boolean supports(Type type) { + return type.equals(Type.REQUEST) || + type.equals(Type.RESPONSE); + } } diff --git a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilter.java b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilter.java index f257013b0..67727fc9f 100644 --- a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilter.java +++ b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilter.java @@ -24,6 +24,7 @@ import org.springframework.core.Ordered; import org.springframework.http.HttpHeaders; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; @ConfigurationProperties("spring.cloud.gateway.x-forwarded") public class XForwardedHeadersFilter implements HttpHeadersFilter, Ordered { @@ -163,8 +164,9 @@ public class XForwardedHeadersFilter implements HttpHeadersFilter, Ordered { } @Override - public HttpHeaders filter(ServerHttpRequest request) { - HttpHeaders original = request.getHeaders(); + public HttpHeaders filter(HttpHeaders input, ServerWebExchange exchange) { + ServerHttpRequest request = exchange.getRequest(); + HttpHeaders original = input; HttpHeaders updated = new HttpHeaders(); original.entrySet().stream() diff --git a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/ForwardedHeadersFilterTests.java b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/ForwardedHeadersFilterTests.java index a8dc95fc9..8a756bc19 100644 --- a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/ForwardedHeadersFilterTests.java +++ b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/ForwardedHeadersFilterTests.java @@ -31,6 +31,7 @@ import org.junit.Test; import org.springframework.cloud.gateway.filter.headers.ForwardedHeadersFilter.Forwarded; import org.springframework.http.HttpHeaders; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.util.StringUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -51,7 +52,7 @@ public class ForwardedHeadersFilterTests { ForwardedHeadersFilter filter = new ForwardedHeadersFilter(); - HttpHeaders headers = filter.filter(request); + HttpHeaders headers = filter.filter(request.getHeaders(), MockServerWebExchange.from(request)); assertThat(headers.get(FORWARDED_HEADER)).hasSize(1); diff --git a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/HttpHeadersFilterMixedTypeTests.java b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/HttpHeadersFilterMixedTypeTests.java new file mode 100644 index 000000000..02acc8e5f --- /dev/null +++ b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/HttpHeadersFilterMixedTypeTests.java @@ -0,0 +1,79 @@ +/* + * Copyright 2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package org.springframework.cloud.gateway.filter.headers; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import org.junit.Test; + +import org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter.Type; +import org.springframework.http.HttpHeaders; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.web.server.ServerWebExchange; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Biju Kunjummen + */ +public class HttpHeadersFilterMixedTypeTests { + + @Test + public void relevantDownstreamFiltersShouldActOnHeaders() { + MockServerHttpRequest mockRequest = MockServerHttpRequest.get("/get") + .header("header1", "value1").header("header2", "value2") + .header("header3", "value3").build(); + + HttpHeadersFilter filter1 = filterRemovingHeaders(Type.RESPONSE, + "header1"); + + HttpHeadersFilter filter2 = filterRemovingHeaders(Type.REQUEST, + "header2"); + + HttpHeaders result = HttpHeadersFilter.filterRequest(Arrays.asList(filter1, filter2), + MockServerWebExchange.from(mockRequest)); + + assertThat(result).containsOnlyKeys("header1", "header3"); + } + + private HttpHeadersFilter filterRemovingHeaders(Type type, + String... headerNames) { + Set headerNamesSet = new HashSet<>(Arrays.asList(headerNames)); + HttpHeadersFilter filter = new HttpHeadersFilter() { + @Override + public HttpHeaders filter(HttpHeaders headers, ServerWebExchange exchange) { + HttpHeaders result = new HttpHeaders(); + headers.entrySet().forEach(entry -> { + if (!headerNamesSet.contains(entry.getKey())) { + result.put(entry.getKey(), entry.getValue()); + } + }); + return result; + } + + @Override + public boolean supports(Type path) { + return path.equals(type); + } + }; + return filter; + } +} diff --git a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/HttpHeadersFilterTests.java b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/HttpHeadersFilterTests.java index cefb37cdf..f0ae66619 100644 --- a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/HttpHeadersFilterTests.java +++ b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/HttpHeadersFilterTests.java @@ -21,43 +21,40 @@ import java.util.Arrays; import java.util.List; import org.junit.Test; + import org.springframework.http.HttpHeaders; -import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; import static org.assertj.core.api.Assertions.assertThat; /** * @author Spencer Gibb + * @author Biju Kunjummen */ public class HttpHeadersFilterTests { @Test public void httpHeadersFilterTests() { MockServerHttpRequest request = MockServerHttpRequest - .get("http://localhost:8080/get") - .header("X-A", "aValue") - .header("X-B", "bValue") - .header("X-C", "cValue") - .build(); + .get("http://localhost:8080/get").header("X-A", "aValue") + .header("X-B", "bValue").header("X-C", "cValue").build(); List filters = Arrays.asList( - r -> HttpHeadersFilterTests.this.filter(r, "X-A"), - r -> HttpHeadersFilterTests.this.filter(r, "X-B") - ); + (h, e) -> HttpHeadersFilterTests.this.filter(h, "X-A"), + (h, e) -> HttpHeadersFilterTests.this.filter(h, "X-B")); - HttpHeaders headers = HttpHeadersFilter.filter(filters, request); + HttpHeaders headers = HttpHeadersFilter.filterRequest(filters, + MockServerWebExchange.from(request)); assertThat(headers).containsOnlyKeys("X-C"); } - private HttpHeaders filter(ServerHttpRequest request, String keyToFilter) { - HttpHeaders original = request.getHeaders(); + private HttpHeaders filter(HttpHeaders input, String keyToFilter) { HttpHeaders filtered = new HttpHeaders(); - original.entrySet().stream() - .filter(entry -> !entry.getKey().equals(keyToFilter)) - .forEach(entry -> filtered.addAll(entry.getKey(), entry.getValue())); + input.entrySet().stream().filter(entry -> !entry.getKey().equals(keyToFilter)) + .forEach(entry -> filtered.addAll(entry.getKey(), entry.getValue())); return filtered; } diff --git a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/RemoveHopByHopHeadersFilterTests.java b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/RemoveHopByHopHeadersFilterTests.java index 7846dac7d..af385eb92 100644 --- a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/RemoveHopByHopHeadersFilterTests.java +++ b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/RemoveHopByHopHeadersFilterTests.java @@ -24,6 +24,7 @@ import java.util.Set; import org.junit.Test; import org.springframework.http.HttpHeaders; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.cloud.gateway.filter.headers.RemoveHopByHopHeadersFilter.HEADERS_REMOVED_ON_REQUEST; @@ -37,10 +38,10 @@ public class RemoveHopByHopHeadersFilterTests { public void happyPath() { MockServerHttpRequest.BaseBuilder builder = MockServerHttpRequest .get("http://localhost/get"); - + HEADERS_REMOVED_ON_REQUEST.forEach(header -> builder.header(header, header+"1")); - testFilter(builder.build()); + testFilter(MockServerWebExchange.from(builder)); } @Test @@ -50,7 +51,7 @@ public class RemoveHopByHopHeadersFilterTests { HEADERS_REMOVED_ON_REQUEST.forEach(header -> builder.header(header.toLowerCase(), header+"1")); - testFilter(builder.build()); + testFilter(MockServerWebExchange.from(builder)); } @Test @@ -62,12 +63,12 @@ public class RemoveHopByHopHeadersFilterTests { builder.header(HttpHeaders.UPGRADE, "WebSocket"); builder.header("Keep-Alive", "timeout:5"); - testFilter(builder.build(), "upgrade", "keep-alive"); + testFilter(MockServerWebExchange.from(builder), "upgrade", "keep-alive"); } - private void testFilter(MockServerHttpRequest request, String... additionalHeaders) { + private void testFilter(MockServerWebExchange exchange, String... additionalHeaders) { RemoveHopByHopHeadersFilter filter = new RemoveHopByHopHeadersFilter(); - HttpHeaders headers = filter.filter(request); + HttpHeaders headers = filter.filter(exchange.getRequest().getHeaders(), exchange); Set toRemove = new HashSet<>(HEADERS_REMOVED_ON_REQUEST); toRemove.addAll(Arrays.asList(additionalHeaders)); diff --git a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilterTests.java b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilterTests.java index 8c112b7a4..a210a962e 100644 --- a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilterTests.java +++ b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilterTests.java @@ -24,6 +24,7 @@ import org.junit.Test; import org.springframework.http.HttpHeaders; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.cloud.gateway.filter.headers.XForwardedHeadersFilter.X_FORWARDED_FOR_HEADER; @@ -46,7 +47,7 @@ public class XForwardedHeadersFilterTests { XForwardedHeadersFilter filter = new XForwardedHeadersFilter(); - HttpHeaders headers = filter.filter(request); + HttpHeaders headers = filter.filter(request.getHeaders(), MockServerWebExchange.from(request)); assertThat(headers).containsKeys(X_FORWARDED_FOR_HEADER, X_FORWARDED_HOST_HEADER, X_FORWARDED_PORT_HEADER, X_FORWARDED_PROTO_HEADER); @@ -67,7 +68,7 @@ public class XForwardedHeadersFilterTests { XForwardedHeadersFilter filter = new XForwardedHeadersFilter(); - HttpHeaders headers = filter.filter(request); + HttpHeaders headers = filter.filter(request.getHeaders(), MockServerWebExchange.from(request)); assertThat(headers).containsKeys(X_FORWARDED_FOR_HEADER, X_FORWARDED_HOST_HEADER, X_FORWARDED_PORT_HEADER, X_FORWARDED_PROTO_HEADER); @@ -91,7 +92,7 @@ public class XForwardedHeadersFilterTests { XForwardedHeadersFilter filter = new XForwardedHeadersFilter(); - HttpHeaders headers = filter.filter(request); + HttpHeaders headers = filter.filter(request.getHeaders(), MockServerWebExchange.from(request)); assertThat(headers).containsKeys(X_FORWARDED_FOR_HEADER, X_FORWARDED_HOST_HEADER, X_FORWARDED_PORT_HEADER, X_FORWARDED_PROTO_HEADER); @@ -119,7 +120,7 @@ public class XForwardedHeadersFilterTests { filter.setPortAppend(false); filter.setProtoAppend(false); - HttpHeaders headers = filter.filter(request); + HttpHeaders headers = filter.filter(request.getHeaders(), MockServerWebExchange.from(request)); assertThat(headers).containsKeys(X_FORWARDED_FOR_HEADER, X_FORWARDED_HOST_HEADER, X_FORWARDED_PORT_HEADER, X_FORWARDED_PROTO_HEADER); @@ -143,7 +144,7 @@ public class XForwardedHeadersFilterTests { filter.setPortEnabled(false); filter.setProtoEnabled(false); - HttpHeaders headers = filter.filter(request); + HttpHeaders headers = filter.filter(request.getHeaders(), MockServerWebExchange.from(request)); assertThat(headers).isEmpty(); } diff --git a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/handler/predicate/PathRoutePredicateFactoryTests.java b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/handler/predicate/PathRoutePredicateFactoryTests.java index bd8a2715e..b3709e328 100644 --- a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/handler/predicate/PathRoutePredicateFactoryTests.java +++ b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/handler/predicate/PathRoutePredicateFactoryTests.java @@ -41,6 +41,7 @@ public class PathRoutePredicateFactoryTests extends BaseWebClientTests { .exchange() .expectStatus().isOk() .expectHeader().valueEquals(HANDLER_MAPPER_HEADER, RoutePredicateHandlerMapping.class.getSimpleName()) + .expectHeader().valueEquals("transfer-encoding", "chunked") .expectHeader().valueEquals(ROUTE_ID_HEADER, "default_path_to_httpbin"); }