diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java index 26ed66c315..e0af31c56d 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -31,6 +31,7 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.http.HttpRequest; import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.util.CollectionUtils; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; @@ -92,6 +93,10 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { private static class ForwardedHeaderRequestWrapper extends HttpServletRequestWrapper { + public static final Enumeration EMPTY_HEADER_VALUES = + Collections.enumeration(Collections.emptyList()); + + private final String scheme; private final boolean secure; @@ -100,9 +105,9 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { private final int port; - private final String portInUrl; + private final StringBuffer requestUrl; - private final Map> headers = new LinkedHashMap>(); + private final Map> headers; public ForwardedHeaderRequestWrapper(HttpServletRequest request) { @@ -116,51 +121,35 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { this.secure = "https".equals(scheme); this.host = uriComponents.getHost(); this.port = (port == -1 ? (this.secure ? 443 : 80) : port); - this.portInUrl = (port == -1 ? "" : ":" + port); + this.requestUrl = initRequestUrl(this.scheme, this.host, port, request.getRequestURI()); + this.headers = initHeaders(request); + } + private static StringBuffer initRequestUrl(String scheme, String host, int port, String path) { + StringBuffer sb = new StringBuffer(); + sb.append(scheme).append("://").append(host); + sb.append(port == -1 ? "" : ":" + port); + sb.append(path); + return sb; + } + + /** + * Copy the headers excluding any {@link #FORWARDED_HEADER_NAMES}. + */ + private static Map> initHeaders(HttpServletRequest request) { + Map> headers = new LinkedHashMap>(); Enumeration headerNames = request.getHeaderNames(); while (headerNames.hasMoreElements()) { String name = headerNames.nextElement(); - this.headers.put(name, Collections.list(request.getHeaders(name))); + headers.put(name, Collections.list(request.getHeaders(name))); } for (String name : FORWARDED_HEADER_NAMES) { - this.headers.remove(name); + headers.remove(name); } + return headers; } - @Override - public String getHeader(String name) { - Map.Entry> header = getHeaderEntry(name); - if (header == null || header.getValue() == null || header.getValue().isEmpty()) { - return null; - } - return header.getValue().get(0); - } - - protected Map.Entry> getHeaderEntry(String name) { - for (Map.Entry> entry : this.headers.entrySet()) { - if (entry.getKey().equalsIgnoreCase(name)) { - return entry; - } - } - return null; - } - - @Override - public Enumeration getHeaderNames() { - return Collections.enumeration(this.headers.keySet()); - } - - @Override - public Enumeration getHeaders(String name) { - Map.Entry> header = getHeaderEntry(name); - if (header == null || header.getValue() == null) { - return Collections.enumeration(Collections.emptyList()); - } - return Collections.enumeration(header.getValue()); - } - @Override public String getScheme() { return this.scheme; @@ -183,10 +172,26 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @Override public StringBuffer getRequestURL() { - StringBuffer sb = new StringBuffer(); - sb.append(this.scheme).append("://").append(this.host).append(this.portInUrl); - sb.append(getRequestURI()); - return sb; + return this.requestUrl; + } + + // Override header accessors in order to not expose forwarded headers + + @Override + public String getHeader(String name) { + List value = this.headers.get(name); + return (CollectionUtils.isEmpty(value) ? null : value.get(0)); + } + + @Override + public Enumeration getHeaders(String name) { + List value = this.headers.get(name); + return (CollectionUtils.isEmpty(value) ? EMPTY_HEADER_VALUES : Collections.enumeration(value)); + } + + @Override + public Enumeration getHeaderNames() { + return Collections.enumeration(this.headers.keySet()); } } diff --git a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java index 207af6518c..62f20c71bb 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java @@ -53,7 +53,7 @@ public class ForwardedHeaderFilterTests { } @Test - public void xForwardedHeaders() throws Exception { + public void forwardedRequest() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest(); request.setScheme("http"); request.setServerName("localhost"); @@ -65,10 +65,9 @@ public class ForwardedHeaderFilterTests { request.addHeader("foo", "bar"); MockFilterChain chain = new MockFilterChain(new HttpServlet() {}); - this.filter.doFilter(request, new MockHttpServletResponse(), chain); - HttpServletRequest actual = (HttpServletRequest) chain.getRequest(); + assertEquals("https://84.198.58.199/mvc-showcase", actual.getRequestURL().toString()); assertEquals("https", actual.getScheme()); assertEquals("84.198.58.199", actual.getServerName());