diff --git a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilter.java b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilter.java index 0e32377a7..789ba171c 100644 --- a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilter.java +++ b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilter.java @@ -17,8 +17,9 @@ package org.springframework.cloud.gateway.filter.headers; +import java.net.URI; +import java.util.LinkedHashSet; import java.util.List; - import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.core.Ordered; import org.springframework.http.HttpHeaders; @@ -26,6 +27,9 @@ import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.util.StringUtils; import org.springframework.web.server.ServerWebExchange; +import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_ORIGINAL_REQUEST_URL_ATTR; +import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR; + @ConfigurationProperties("spring.cloud.gateway.x-forwarded") public class XForwardedHeadersFilter implements HttpHeadersFilter, Ordered { /** default http port */ @@ -52,6 +56,10 @@ public class XForwardedHeadersFilter implements HttpHeadersFilter, Ordered { /** X-Forwarded-Proto Header */ public static final String X_FORWARDED_PROTO_HEADER = "X-Forwarded-Proto"; + /** X-Forwarded-Prefix Header */ + public static final String X_FORWARDED_PREFIX_HEADER = "X-Forwarded-Prefix"; + + /** The order of the XForwardedHeadersFilter. */ private int order = 0; @@ -70,6 +78,9 @@ public class XForwardedHeadersFilter implements HttpHeadersFilter, Ordered { /** If X-Forwarded-Proto is enabled. */ private boolean protoEnabled = true; + /** If X-Forwarded-Prefix is enabled. */ + private boolean prefixEnabled = true; + /** If appending X-Forwarded-For as a list is enabled. */ private boolean forAppend = true; @@ -82,6 +93,9 @@ public class XForwardedHeadersFilter implements HttpHeadersFilter, Ordered { /** If appending X-Forwarded-Proto as a list is enabled. */ private boolean protoAppend = true; + /** If appending X-Forwarded-Prefix as a list is enabled. */ + private boolean prefixAppend = true; + @Override public int getOrder() { return this.order; @@ -131,6 +145,14 @@ public class XForwardedHeadersFilter implements HttpHeadersFilter, Ordered { this.protoEnabled = protoEnabled; } + public boolean isPrefixEnabled() { + return prefixEnabled; + } + + public void setPrefixEnabled(boolean prefixEnabled) { + this.prefixEnabled = prefixEnabled; + } + public boolean isForAppend() { return forAppend; } @@ -163,8 +185,18 @@ public class XForwardedHeadersFilter implements HttpHeadersFilter, Ordered { this.protoAppend = protoAppend; } + public void setPrefixAppend(boolean prefixAppend) { + this.prefixAppend = prefixAppend; + } + + public boolean isPrefixAppend() { + return prefixAppend; + } + @Override public HttpHeaders filter(HttpHeaders input, ServerWebExchange exchange) { + + ServerHttpRequest request = exchange.getRequest(); HttpHeaders original = input; HttpHeaders updated = new HttpHeaders(); @@ -189,6 +221,37 @@ public class XForwardedHeadersFilter implements HttpHeadersFilter, Ordered { write(updated, X_FORWARDED_PROTO_HEADER, proto, isProtoAppend()); } + if(isPrefixEnabled()) { + //if the path of the url that the gw is routing to is a subset (and ending part) of the url that it is routing from then the difference is the prefix + //e.g. if request original.com/prefix/get/ is routed to routedservice:8090/get then /prefix is the prefix - see XForwardedHeadersFilterTests + //so first get uris, then extract paths and remove one from another if it's the ending part + + LinkedHashSet originalUris = exchange.getAttribute(GATEWAY_ORIGINAL_REQUEST_URL_ATTR); + URI requestUri = exchange.getAttribute(GATEWAY_REQUEST_URL_ATTR); + + if(originalUris != null && requestUri != null) { + + originalUris.stream().forEach(originalUri -> { + + if(originalUri!=null && originalUri.getPath()!=null) { + String prefix = originalUri.getPath(); + + //strip trailing slashes before checking if request path is end of original path + String originalUriPath = stripTrailingSlash(originalUri); + String requestUriPath = stripTrailingSlash(requestUri); + + if(requestUriPath!=null && (originalUriPath.endsWith(requestUriPath))) { + prefix = originalUriPath.replace(requestUriPath, ""); + } + if (prefix != null && prefix.length() > 0 && + prefix.length() < originalUri.getPath().length()) { + write(updated, X_FORWARDED_PREFIX_HEADER, prefix, isPrefixAppend()); + } + } + }); + } + } + if (isPortEnabled()) { String port = String.valueOf(request.getURI().getPort()); if (request.getURI().getPort() < 0) { @@ -240,4 +303,12 @@ public class XForwardedHeadersFilter implements HttpHeadersFilter, Ordered { return host + ":" + port; } } -} + + private String stripTrailingSlash(URI uri) { + if (uri.getPath().endsWith("/")) { + return uri.getPath().substring(0, uri.getPath().length() - 1); + } else { + return uri.getPath(); + } + } +} \ No newline at end of file diff --git a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilterTests.java b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilterTests.java index a7376571f..31bade18a 100644 --- a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilterTests.java +++ b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/headers/XForwardedHeadersFilterTests.java @@ -19,18 +19,25 @@ package org.springframework.cloud.gateway.filter.headers; import java.net.InetAddress; import java.net.InetSocketAddress; +import java.net.URI; +import java.util.LinkedHashSet; import org.junit.Test; import org.springframework.http.HttpHeaders; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.cloud.gateway.filter.headers.XForwardedHeadersFilter.X_FORWARDED_FOR_HEADER; import static org.springframework.cloud.gateway.filter.headers.XForwardedHeadersFilter.X_FORWARDED_HOST_HEADER; import static org.springframework.cloud.gateway.filter.headers.XForwardedHeadersFilter.X_FORWARDED_PORT_HEADER; +import static org.springframework.cloud.gateway.filter.headers.XForwardedHeadersFilter.X_FORWARDED_PREFIX_HEADER; import static org.springframework.cloud.gateway.filter.headers.XForwardedHeadersFilter.X_FORWARDED_PROTO_HEADER; +import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_ORIGINAL_REQUEST_URL_ATTR; +import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR; /** * @author Spencer Gibb @@ -131,6 +138,7 @@ public class XForwardedHeadersFilterTests { .header(X_FORWARDED_HOST_HEADER, "example.com") .header(X_FORWARDED_PORT_HEADER, "443") .header(X_FORWARDED_PROTO_HEADER, "https") + .header(X_FORWARDED_PREFIX_HEADER,"/prefix") .build(); XForwardedHeadersFilter filter = new XForwardedHeadersFilter(); @@ -138,16 +146,98 @@ public class XForwardedHeadersFilterTests { filter.setHostAppend(false); filter.setPortAppend(false); filter.setProtoAppend(false); + filter.setPrefixAppend(false); HttpHeaders headers = filter.filter(request.getHeaders(), MockServerWebExchange.from(request)); assertThat(headers).containsKeys(X_FORWARDED_FOR_HEADER, X_FORWARDED_HOST_HEADER, - X_FORWARDED_PORT_HEADER, X_FORWARDED_PROTO_HEADER); + X_FORWARDED_PORT_HEADER, X_FORWARDED_PROTO_HEADER,X_FORWARDED_PREFIX_HEADER); assertThat(headers.getFirst(X_FORWARDED_FOR_HEADER)).isEqualTo("10.0.0.1"); assertThat(headers.getFirst(X_FORWARDED_HOST_HEADER)).isEqualTo("localhost:8080"); assertThat(headers.getFirst(X_FORWARDED_PORT_HEADER)).isEqualTo("8080"); assertThat(headers.getFirst(X_FORWARDED_PROTO_HEADER)).isEqualTo("http"); + assertThat(headers.getFirst(X_FORWARDED_PREFIX_HEADER)).isEqualTo("/prefix"); + } + + + @Test + public void prefixToInfer() throws Exception { + MockServerHttpRequest request = MockServerHttpRequest + .get("http://originalhost:8080/prefix/get") + .remoteAddress(new InetSocketAddress(InetAddress.getByName("10.0.0.1"), 80)) + .build(); + + XForwardedHeadersFilter filter = new XForwardedHeadersFilter(); + filter.setPrefixAppend(true); + filter.setPrefixEnabled(true); + + ServerWebExchange exchange = MockServerWebExchange.from(request); + LinkedHashSet originalUris = new LinkedHashSet<>(); + originalUris.add(UriComponentsBuilder.fromUriString("http://originalhost:8080/prefix/get/").build().toUri()); //trailing slash + exchange.getAttributes().put(GATEWAY_ORIGINAL_REQUEST_URL_ATTR, originalUris); + URI requestUri = UriComponentsBuilder.fromUriString("http://routedservice:8090/get").build().toUri(); + exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, requestUri); + + HttpHeaders headers = filter.filter(request.getHeaders(), exchange); + + assertThat(headers).containsKeys(X_FORWARDED_PREFIX_HEADER); + + assertThat(headers.getFirst(X_FORWARDED_PREFIX_HEADER)).isEqualTo("/prefix"); + } + + @Test + public void noPrefixToInfer() throws Exception { + MockServerHttpRequest request = MockServerHttpRequest + .get("http://originalhost:8080/get") + .remoteAddress(new InetSocketAddress(InetAddress.getByName("10.0.0.1"), 80)) + .build(); + + XForwardedHeadersFilter filter = new XForwardedHeadersFilter(); + filter.setPrefixAppend(true); + filter.setPrefixEnabled(true); + filter.setForEnabled(false); + filter.setHostEnabled(false); + filter.setPortEnabled(false); + filter.setProtoEnabled(false); + + ServerWebExchange exchange = MockServerWebExchange.from(request); + LinkedHashSet originalUris = new LinkedHashSet<>(); + originalUris.add(UriComponentsBuilder.fromUriString("http://originalhost:8080/get/").build().toUri()); + exchange.getAttributes().put(GATEWAY_ORIGINAL_REQUEST_URL_ATTR, originalUris); + URI requestUri = UriComponentsBuilder.fromUriString("http://routedservice:8090/get").build().toUri(); + exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, requestUri); + + HttpHeaders headers = filter.filter(request.getHeaders(), exchange); + + assertThat(headers).isEmpty(); + } + + @Test + public void routedPathInRequestPathButNotPrefix() throws Exception { + MockServerHttpRequest request = MockServerHttpRequest + .get("http://originalhost:8080/get") + .remoteAddress(new InetSocketAddress(InetAddress.getByName("10.0.0.1"), 80)) + .build(); + + XForwardedHeadersFilter filter = new XForwardedHeadersFilter(); + filter.setPrefixAppend(true); + filter.setPrefixEnabled(true); + filter.setForEnabled(false); + filter.setHostEnabled(false); + filter.setPortEnabled(false); + filter.setProtoEnabled(false); + + ServerWebExchange exchange = MockServerWebExchange.from(request); + LinkedHashSet originalUris = new LinkedHashSet<>(); + originalUris.add(UriComponentsBuilder.fromUriString("http://originalhost:8080/one/two/three").build().toUri()); + exchange.getAttributes().put(GATEWAY_ORIGINAL_REQUEST_URL_ATTR, originalUris); + URI requestUri = UriComponentsBuilder.fromUriString("http://routedservice:8090/two").build().toUri(); + exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, requestUri); + + HttpHeaders headers = filter.filter(request.getHeaders(), exchange); + + assertThat(headers).isEmpty(); } @Test @@ -162,6 +252,7 @@ public class XForwardedHeadersFilterTests { filter.setHostEnabled(false); filter.setPortEnabled(false); filter.setProtoEnabled(false); + filter.setPrefixEnabled(false); HttpHeaders headers = filter.filter(request.getHeaders(), MockServerWebExchange.from(request));