Browse Source

ForwardedHeaderFilter works with Servlet FORWARD

Issue: SPR-16983
pull/1999/head
Rossen Stoyanchev 6 years ago
parent
commit
feeec344e5
  1. 140
      spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java
  2. 24
      spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java

140
spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java

@ -23,6 +23,7 @@ import java.util.List; @@ -23,6 +23,7 @@ import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
@ -219,11 +220,7 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @@ -219,11 +220,7 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
private final int port;
private final String contextPath;
private final String requestUri;
private final String requestUrl;
private final ForwardedPrefixExtractor forwardedPrefixExtractor;
ForwardedHeaderExtractingRequest(HttpServletRequest request, UrlPathHelper pathHelper) {
@ -238,28 +235,9 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @@ -238,28 +235,9 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
this.host = uriComponents.getHost();
this.port = (port == -1 ? (this.secure ? 443 : 80) : port);
String prefix = getForwardedPrefix(request);
this.contextPath = (prefix != null ? prefix : request.getContextPath());
this.requestUri = this.contextPath + pathHelper.getPathWithinApplication(request);
this.requestUrl = this.scheme + "://" + this.host + (port == -1 ? "" : ":" + port) + this.requestUri;
}
@Nullable
private static String getForwardedPrefix(HttpServletRequest request) {
String prefix = null;
Enumeration<String> names = request.getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
if ("X-Forwarded-Prefix".equalsIgnoreCase(name)) {
prefix = request.getHeader(name);
}
}
if (prefix != null) {
while (prefix.endsWith("/")) {
prefix = prefix.substring(0, prefix.length() - 1);
}
}
return prefix;
String baseUrl = this.scheme + "://" + this.host + (port == -1 ? "" : ":" + port);
Supplier<HttpServletRequest> delegateRequest = () -> (HttpServletRequest) getRequest();
this.forwardedPrefixExtractor = new ForwardedPrefixExtractor(delegateRequest, pathHelper, baseUrl);
}
@ -287,18 +265,122 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @@ -287,18 +265,122 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
@Override
public String getContextPath() {
return this.contextPath;
return this.forwardedPrefixExtractor.getContextPath();
}
@Override
public String getRequestURI() {
return this.requestUri;
return this.forwardedPrefixExtractor.getRequestUri();
}
@Override
public StringBuffer getRequestURL() {
return this.forwardedPrefixExtractor.getRequestUrl();
}
}
/**
* Responsible for the contextPath, requestURI, and requestURL with forwarded
* headers in mind, and also taking into account changes to the path of the
* underlying delegate request (e.g. on a Servlet FORWARD).
*/
private static class ForwardedPrefixExtractor {
private final Supplier<HttpServletRequest> delegate;
private final UrlPathHelper pathHelper;
private final String baseUrl;
private String actualRequestUri;
@Nullable
private final String forwardedPrefix;
@Nullable
private String requestUri;
private String requestUrl;
/**
* Constructor with required information.
* @param delegateRequest supplier for the current
* {@link HttpServletRequestWrapper#getRequest() delegate request} which
* may change during a forward (e.g. Tocat.
* @param pathHelper the path helper instance
* @param baseUrl the host, scheme, and port based on forwarded headers
*/
public ForwardedPrefixExtractor(
Supplier<HttpServletRequest> delegateRequest, UrlPathHelper pathHelper, String baseUrl) {
this.delegate = delegateRequest;
this.pathHelper = pathHelper;
this.baseUrl = baseUrl;
this.actualRequestUri = delegateRequest.get().getRequestURI();
this.forwardedPrefix = initForwardedPrefix(delegateRequest.get());
this.requestUri = initRequestUri();
this.requestUrl = initRequestUrl(); // Keep the order: depends on requestUri
}
@Nullable
private static String initForwardedPrefix(HttpServletRequest request) {
String result = null;
Enumeration<String> names = request.getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
if ("X-Forwarded-Prefix".equalsIgnoreCase(name)) {
result = request.getHeader(name);
}
}
if (result != null) {
while (result.endsWith("/")) {
result = result.substring(0, result.length() - 1);
}
}
return result;
}
@Nullable
private String initRequestUri() {
if (this.forwardedPrefix != null) {
return this.forwardedPrefix + this.pathHelper.getPathWithinApplication(this.delegate.get());
}
return null;
}
private String initRequestUrl() {
return this.baseUrl + (this.requestUri != null ? this.requestUri : this.delegate.get().getRequestURI());
}
public String getContextPath() {
return this.forwardedPrefix == null ? this.delegate.get().getContextPath() : this.forwardedPrefix;
}
public String getRequestUri() {
if (this.requestUri == null) {
return this.delegate.get().getRequestURI();
}
recalculatePathsIfNecesary();
return this.requestUri;
}
public StringBuffer getRequestUrl() {
recalculatePathsIfNecesary();
return new StringBuffer(this.requestUrl);
}
private void recalculatePathsIfNecesary() {
if (!this.actualRequestUri.equals(this.delegate.get().getRequestURI())) {
// Underlying path change (e.g. Servlet FORWARD).
this.actualRequestUri = this.delegate.get().getRequestURI();
this.requestUri = initRequestUri();
this.requestUrl = initRequestUrl(); // Keep the order: depends on requestUri
}
}
}

24
spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java

@ -17,7 +17,9 @@ @@ -17,7 +17,9 @@
package org.springframework.web.filter;
import java.io.IOException;
import java.net.URI;
import java.util.Enumeration;
import javax.servlet.DispatcherType;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
@ -308,6 +310,28 @@ public class ForwardedHeaderFilterTests { @@ -308,6 +310,28 @@ public class ForwardedHeaderFilterTests {
assertEquals("bar", actual.getHeader("foo"));
}
@Test // SPR-16983
public void forwardedRequestWithServletForward() throws Exception {
this.request.setRequestURI("/foo");
this.request.addHeader(X_FORWARDED_PROTO, "https");
this.request.addHeader(X_FORWARDED_HOST, "www.mycompany.com");
this.request.addHeader(X_FORWARDED_PORT, "443");
this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain);
HttpServletRequest wrappedRequest = (HttpServletRequest) this.filterChain.getRequest();
this.request.setDispatcherType(DispatcherType.FORWARD);
this.request.setRequestURI("/bar");
this.filterChain.reset();
this.filter.doFilter(wrappedRequest, new MockHttpServletResponse(), this.filterChain);
HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest();
assertNotNull(actual);
assertEquals("/bar", actual.getRequestURI());
assertEquals("https://www.mycompany.com/bar", actual.getRequestURL().toString());
}
@Test
public void requestUriWithForwardedPrefix() throws Exception {
this.request.addHeader(X_FORWARDED_PREFIX, "/prefix");

Loading…
Cancel
Save