diff --git a/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java index 65b146a469..ebf071b799 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java @@ -44,7 +44,7 @@ import org.springframework.web.util.UriComponentsBuilder; */ public class ForwardedHeaderFilter implements WebFilter { - private static final Set FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5); + static final Set FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5); static { FORWARDED_HEADER_NAMES.add("Forwarded"); @@ -72,54 +72,58 @@ public class ForwardedHeaderFilter implements WebFilter { @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { - if (shouldNotFilter(exchange.getRequest())) { + ServerHttpRequest request = exchange.getRequest(); + if (!hasForwardedHeaders(request)) { return chain.filter(exchange); } ServerWebExchange mutatedExchange; - if (this.removeOnly) { - mutatedExchange = exchange.mutate().request(builder -> - builder.headers(headers -> { - FORWARDED_HEADER_NAMES.forEach(headers::remove); - })) - .build(); + mutatedExchange = exchange.mutate().request(this::removeForwardedHeaders).build(); } else { - URI uri = UriComponentsBuilder.fromHttpRequest(exchange.getRequest()).build().toUri(); - String prefix = getForwardedPrefix(exchange.getRequest().getHeaders()); - - mutatedExchange = exchange.mutate().request(builder -> { - builder.uri(uri); - if (prefix != null) { - builder.path(prefix + uri.getPath()); - builder.contextPath(prefix); - } - }).build(); + mutatedExchange = exchange.mutate() + .request(builder -> { + URI uri = UriComponentsBuilder.fromHttpRequest(request).build().toUri(); + builder.uri(uri); + String prefix = getForwardedPrefix(request); + if (prefix != null) { + builder.path(prefix + uri.getPath()); + builder.contextPath(prefix); + } + }) + .build(); } return chain.filter(mutatedExchange); } - private boolean shouldNotFilter(ServerHttpRequest request) { + private boolean hasForwardedHeaders(ServerHttpRequest request) { HttpHeaders headers = request.getHeaders(); for (String headerName : FORWARDED_HEADER_NAMES) { if (headers.containsKey(headerName)) { - return false; + return true; } } - return true; + return false; } @Nullable - private static String getForwardedPrefix(HttpHeaders headers) { + private static String getForwardedPrefix(ServerHttpRequest request) { + HttpHeaders headers = request.getHeaders(); String prefix = headers.getFirst("X-Forwarded-Prefix"); if (prefix != null) { - while (prefix.endsWith("/")) { - prefix = prefix.substring(0, prefix.length() - 1); - } + int endIndex = prefix.length(); + while (endIndex > 1 && prefix.charAt(endIndex - 1) == '/') { + endIndex--; + }; + prefix = endIndex != prefix.length() ? prefix.substring(0, endIndex) : prefix; } return prefix; } + private ServerHttpRequest.Builder removeForwardedHeaders(ServerHttpRequest.Builder builder) { + return builder.headers(map -> FORWARDED_HEADER_NAMES.forEach(map::remove)); + } + } diff --git a/spring-web/src/test/java/org/springframework/web/filter/reactive/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/reactive/ForwardedHeaderFilterTests.java index 5f5024d692..348cd30e3f 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/reactive/ForwardedHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/reactive/ForwardedHeaderFilterTests.java @@ -23,16 +23,19 @@ import org.junit.Test; import reactor.core.publisher.Mono; import org.springframework.http.HttpHeaders; +import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.lang.Nullable; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.web.test.server.MockServerWebExchange; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilterChain; import static org.junit.Assert.*; -import static org.springframework.mock.http.server.reactive.test.MockServerHttpRequest.*; /** + * Unit tests for {@link ForwardedHeaderFilter}. * @author Arjen Poutsma + * @author Rossen Stoyanchev */ public class ForwardedHeaderFilterTests { @@ -46,65 +49,65 @@ public class ForwardedHeaderFilterTests { @Test public void removeOnly() { - ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL) - .header("Forwarded", "for=192.0.2.60;proto=http;by=203.0.113.43") - .header("X-Forwarded-Host", "example.com") - .header("X-Forwarded-Port", "8080") - .header("X-Forwarded-Proto", "http") - .header("X-Forwarded-Prefix", "prefix") - .header("X-Forwarded-Ssl", "on")); this.filter.setRemoveOnly(true); - this.filter.filter(exchange, this.filterChain).block(Duration.ZERO); - - HttpHeaders result = this.filterChain.getHeaders(); - assertNotNull(result); - assertFalse(result.containsKey("Forwarded")); - assertFalse(result.containsKey("X-Forwarded-Host")); - assertFalse(result.containsKey("X-Forwarded-Port")); - assertFalse(result.containsKey("X-Forwarded-Proto")); - assertFalse(result.containsKey("X-Forwarded-Prefix")); - assertFalse(result.containsKey("X-Forwarded-Ssl")); + + HttpHeaders headers = new HttpHeaders(); + headers.add("Forwarded", "for=192.0.2.60;proto=http;by=203.0.113.43"); + headers.add("X-Forwarded-Host", "example.com"); + headers.add("X-Forwarded-Port", "8080"); + headers.add("X-Forwarded-Proto", "http"); + headers.add("X-Forwarded-Prefix", "prefix"); + headers.add("X-Forwarded-Ssl", "on"); + this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO); + + this.filterChain.assertForwardedHeadersRemoved(); } @Test - public void xForwardedRequest() throws Exception { - ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL) - .header("X-Forwarded-Host", "84.198.58.199") - .header("X-Forwarded-Port", "443") - .header("X-Forwarded-Proto", "https")); - - assertEquals(new URI("https://84.198.58.199/path"), filterAndGetUri(exchange)); + public void xForwardedHeaders() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add("X-Forwarded-Host", "84.198.58.199"); + headers.add("X-Forwarded-Port", "443"); + headers.add("X-Forwarded-Proto", "https"); + headers.add("foo", "bar"); + this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO); + + assertEquals(new URI("https://84.198.58.199/path"), this.filterChain.uri); } @Test - public void forwardedRequest() throws Exception { - ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL) - .header("Forwarded", "host=84.198.58.199;proto=https")); + public void forwardedHeader() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add("Forwarded", "host=84.198.58.199;proto=https"); + this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO); - assertEquals(new URI("https://84.198.58.199/path"), filterAndGetUri(exchange)); + assertEquals(new URI("https://84.198.58.199/path"), this.filterChain.uri); } @Test - public void requestUriWithForwardedPrefix() throws Exception { - ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL) - .header("X-Forwarded-Prefix", "/prefix")); + public void xForwardedPrefix() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add("X-Forwarded-Prefix", "/prefix"); + this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO); - assertEquals(new URI("http://example.com/prefix/path"), filterAndGetUri(exchange)); + assertEquals(new URI("http://example.com/prefix/path"), this.filterChain.uri); + assertEquals("/prefix/path", this.filterChain.requestPathValue); } @Test - public void requestUriWithForwardedPrefixTrailingSlash() throws Exception { - ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL) - .header("X-Forwarded-Prefix", "/prefix/")); + public void xForwardedPrefixTrailingSlash() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add("X-Forwarded-Prefix", "/prefix////"); + this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO); - assertEquals(new URI("http://example.com/prefix/path"), filterAndGetUri(exchange)); + assertEquals(new URI("http://example.com/prefix/path"), this.filterChain.uri); + assertEquals("/prefix/path", this.filterChain.requestPathValue); } - @Nullable - private URI filterAndGetUri(ServerWebExchange exchange) { - this.filter.filter(exchange, this.filterChain).block(Duration.ZERO); - return this.filterChain.uri; + private MockServerWebExchange getExchange(HttpHeaders headers) { + MockServerHttpRequest request = MockServerHttpRequest.get(BASE_URL).headers(headers).build(); + return MockServerWebExchange.from(request); } @@ -116,12 +119,26 @@ public class ForwardedHeaderFilterTests { @Nullable private URI uri; + @Nullable String requestPathValue; + @Nullable public HttpHeaders getHeaders() { return this.headers; } + @Nullable + public String getHeader(String name) { + assertNotNull(this.headers); + return this.headers.getFirst(name); + } + + public void assertForwardedHeadersRemoved() { + assertNotNull(this.headers); + ForwardedHeaderFilter.FORWARDED_HEADER_NAMES + .forEach(name -> assertFalse(this.headers.containsKey(name))); + } + @Nullable public URI getUri() { return this.uri; @@ -129,8 +146,10 @@ public class ForwardedHeaderFilterTests { @Override public Mono filter(ServerWebExchange exchange) { - this.headers = exchange.getRequest().getHeaders(); - this.uri = exchange.getRequest().getURI(); + ServerHttpRequest request = exchange.getRequest(); + this.headers = request.getHeaders(); + this.uri = request.getURI(); + this.requestPathValue = request.getPath().value(); return Mono.empty(); } }