Browse Source

Polish WebFlux ForwardedHeaderFilter and tests

Preparation for SPR-17072
pull/1895/merge
Rossen Stoyanchev 6 years ago
parent
commit
41aa4218af
  1. 54
      spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java
  2. 105
      spring-web/src/test/java/org/springframework/web/filter/reactive/ForwardedHeaderFilterTests.java

54
spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java

@ -44,7 +44,7 @@ import org.springframework.web.util.UriComponentsBuilder; @@ -44,7 +44,7 @@ import org.springframework.web.util.UriComponentsBuilder;
*/
public class ForwardedHeaderFilter implements WebFilter {
private static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
static {
FORWARDED_HEADER_NAMES.add("Forwarded");
@ -72,54 +72,58 @@ public class ForwardedHeaderFilter implements WebFilter { @@ -72,54 +72,58 @@ public class ForwardedHeaderFilter implements WebFilter {
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
if (shouldNotFilter(exchange.getRequest())) {
ServerHttpRequest request = exchange.getRequest();
if (!hasForwardedHeaders(request)) {
return chain.filter(exchange);
}
ServerWebExchange mutatedExchange;
if (this.removeOnly) {
mutatedExchange = exchange.mutate().request(builder ->
builder.headers(headers -> {
FORWARDED_HEADER_NAMES.forEach(headers::remove);
}))
.build();
mutatedExchange = exchange.mutate().request(this::removeForwardedHeaders).build();
}
else {
URI uri = UriComponentsBuilder.fromHttpRequest(exchange.getRequest()).build().toUri();
String prefix = getForwardedPrefix(exchange.getRequest().getHeaders());
mutatedExchange = exchange.mutate().request(builder -> {
builder.uri(uri);
if (prefix != null) {
builder.path(prefix + uri.getPath());
builder.contextPath(prefix);
}
}).build();
mutatedExchange = exchange.mutate()
.request(builder -> {
URI uri = UriComponentsBuilder.fromHttpRequest(request).build().toUri();
builder.uri(uri);
String prefix = getForwardedPrefix(request);
if (prefix != null) {
builder.path(prefix + uri.getPath());
builder.contextPath(prefix);
}
})
.build();
}
return chain.filter(mutatedExchange);
}
private boolean shouldNotFilter(ServerHttpRequest request) {
private boolean hasForwardedHeaders(ServerHttpRequest request) {
HttpHeaders headers = request.getHeaders();
for (String headerName : FORWARDED_HEADER_NAMES) {
if (headers.containsKey(headerName)) {
return false;
return true;
}
}
return true;
return false;
}
@Nullable
private static String getForwardedPrefix(HttpHeaders headers) {
private static String getForwardedPrefix(ServerHttpRequest request) {
HttpHeaders headers = request.getHeaders();
String prefix = headers.getFirst("X-Forwarded-Prefix");
if (prefix != null) {
while (prefix.endsWith("/")) {
prefix = prefix.substring(0, prefix.length() - 1);
}
int endIndex = prefix.length();
while (endIndex > 1 && prefix.charAt(endIndex - 1) == '/') {
endIndex--;
};
prefix = endIndex != prefix.length() ? prefix.substring(0, endIndex) : prefix;
}
return prefix;
}
private ServerHttpRequest.Builder removeForwardedHeaders(ServerHttpRequest.Builder builder) {
return builder.headers(map -> FORWARDED_HEADER_NAMES.forEach(map::remove));
}
}

105
spring-web/src/test/java/org/springframework/web/filter/reactive/ForwardedHeaderFilterTests.java

@ -23,16 +23,19 @@ import org.junit.Test; @@ -23,16 +23,19 @@ import org.junit.Test;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.web.test.server.MockServerWebExchange;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import static org.junit.Assert.*;
import static org.springframework.mock.http.server.reactive.test.MockServerHttpRequest.*;
/**
* Unit tests for {@link ForwardedHeaderFilter}.
* @author Arjen Poutsma
* @author Rossen Stoyanchev
*/
public class ForwardedHeaderFilterTests {
@ -46,65 +49,65 @@ public class ForwardedHeaderFilterTests { @@ -46,65 +49,65 @@ public class ForwardedHeaderFilterTests {
@Test
public void removeOnly() {
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL)
.header("Forwarded", "for=192.0.2.60;proto=http;by=203.0.113.43")
.header("X-Forwarded-Host", "example.com")
.header("X-Forwarded-Port", "8080")
.header("X-Forwarded-Proto", "http")
.header("X-Forwarded-Prefix", "prefix")
.header("X-Forwarded-Ssl", "on"));
this.filter.setRemoveOnly(true);
this.filter.filter(exchange, this.filterChain).block(Duration.ZERO);
HttpHeaders result = this.filterChain.getHeaders();
assertNotNull(result);
assertFalse(result.containsKey("Forwarded"));
assertFalse(result.containsKey("X-Forwarded-Host"));
assertFalse(result.containsKey("X-Forwarded-Port"));
assertFalse(result.containsKey("X-Forwarded-Proto"));
assertFalse(result.containsKey("X-Forwarded-Prefix"));
assertFalse(result.containsKey("X-Forwarded-Ssl"));
HttpHeaders headers = new HttpHeaders();
headers.add("Forwarded", "for=192.0.2.60;proto=http;by=203.0.113.43");
headers.add("X-Forwarded-Host", "example.com");
headers.add("X-Forwarded-Port", "8080");
headers.add("X-Forwarded-Proto", "http");
headers.add("X-Forwarded-Prefix", "prefix");
headers.add("X-Forwarded-Ssl", "on");
this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
this.filterChain.assertForwardedHeadersRemoved();
}
@Test
public void xForwardedRequest() throws Exception {
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL)
.header("X-Forwarded-Host", "84.198.58.199")
.header("X-Forwarded-Port", "443")
.header("X-Forwarded-Proto", "https"));
assertEquals(new URI("https://84.198.58.199/path"), filterAndGetUri(exchange));
public void xForwardedHeaders() throws Exception {
HttpHeaders headers = new HttpHeaders();
headers.add("X-Forwarded-Host", "84.198.58.199");
headers.add("X-Forwarded-Port", "443");
headers.add("X-Forwarded-Proto", "https");
headers.add("foo", "bar");
this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
assertEquals(new URI("https://84.198.58.199/path"), this.filterChain.uri);
}
@Test
public void forwardedRequest() throws Exception {
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL)
.header("Forwarded", "host=84.198.58.199;proto=https"));
public void forwardedHeader() throws Exception {
HttpHeaders headers = new HttpHeaders();
headers.add("Forwarded", "host=84.198.58.199;proto=https");
this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
assertEquals(new URI("https://84.198.58.199/path"), filterAndGetUri(exchange));
assertEquals(new URI("https://84.198.58.199/path"), this.filterChain.uri);
}
@Test
public void requestUriWithForwardedPrefix() throws Exception {
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL)
.header("X-Forwarded-Prefix", "/prefix"));
public void xForwardedPrefix() throws Exception {
HttpHeaders headers = new HttpHeaders();
headers.add("X-Forwarded-Prefix", "/prefix");
this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
assertEquals(new URI("http://example.com/prefix/path"), filterAndGetUri(exchange));
assertEquals(new URI("http://example.com/prefix/path"), this.filterChain.uri);
assertEquals("/prefix/path", this.filterChain.requestPathValue);
}
@Test
public void requestUriWithForwardedPrefixTrailingSlash() throws Exception {
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL)
.header("X-Forwarded-Prefix", "/prefix/"));
public void xForwardedPrefixTrailingSlash() throws Exception {
HttpHeaders headers = new HttpHeaders();
headers.add("X-Forwarded-Prefix", "/prefix////");
this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
assertEquals(new URI("http://example.com/prefix/path"), filterAndGetUri(exchange));
assertEquals(new URI("http://example.com/prefix/path"), this.filterChain.uri);
assertEquals("/prefix/path", this.filterChain.requestPathValue);
}
@Nullable
private URI filterAndGetUri(ServerWebExchange exchange) {
this.filter.filter(exchange, this.filterChain).block(Duration.ZERO);
return this.filterChain.uri;
private MockServerWebExchange getExchange(HttpHeaders headers) {
MockServerHttpRequest request = MockServerHttpRequest.get(BASE_URL).headers(headers).build();
return MockServerWebExchange.from(request);
}
@ -116,12 +119,26 @@ public class ForwardedHeaderFilterTests { @@ -116,12 +119,26 @@ public class ForwardedHeaderFilterTests {
@Nullable
private URI uri;
@Nullable String requestPathValue;
@Nullable
public HttpHeaders getHeaders() {
return this.headers;
}
@Nullable
public String getHeader(String name) {
assertNotNull(this.headers);
return this.headers.getFirst(name);
}
public void assertForwardedHeadersRemoved() {
assertNotNull(this.headers);
ForwardedHeaderFilter.FORWARDED_HEADER_NAMES
.forEach(name -> assertFalse(this.headers.containsKey(name)));
}
@Nullable
public URI getUri() {
return this.uri;
@ -129,8 +146,10 @@ public class ForwardedHeaderFilterTests { @@ -129,8 +146,10 @@ public class ForwardedHeaderFilterTests {
@Override
public Mono<Void> filter(ServerWebExchange exchange) {
this.headers = exchange.getRequest().getHeaders();
this.uri = exchange.getRequest().getURI();
ServerHttpRequest request = exchange.getRequest();
this.headers = request.getHeaders();
this.uri = request.getURI();
this.requestPathValue = request.getPath().value();
return Mono.empty();
}
}

Loading…
Cancel
Save