Browse Source

Polish ForwardedHeaderFilter

Issue: SPR-13614
pull/985/head
Rossen Stoyanchev 9 years ago
parent
commit
6fcc869338
  1. 87
      spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java
  2. 5
      spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java

87
spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java

@ -31,6 +31,7 @@ import javax.servlet.http.HttpServletResponse; @@ -31,6 +31,7 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpRequest;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.util.CollectionUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
@ -92,6 +93,10 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @@ -92,6 +93,10 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
private static class ForwardedHeaderRequestWrapper extends HttpServletRequestWrapper {
public static final Enumeration<String> EMPTY_HEADER_VALUES =
Collections.enumeration(Collections.<String>emptyList());
private final String scheme;
private final boolean secure;
@ -100,9 +105,9 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @@ -100,9 +105,9 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
private final int port;
private final String portInUrl;
private final StringBuffer requestUrl;
private final Map<String, List<String>> headers = new LinkedHashMap<String, List<String>>();
private final Map<String, List<String>> headers;
public ForwardedHeaderRequestWrapper(HttpServletRequest request) {
@ -116,51 +121,35 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @@ -116,51 +121,35 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
this.secure = "https".equals(scheme);
this.host = uriComponents.getHost();
this.port = (port == -1 ? (this.secure ? 443 : 80) : port);
this.portInUrl = (port == -1 ? "" : ":" + port);
this.requestUrl = initRequestUrl(this.scheme, this.host, port, request.getRequestURI());
this.headers = initHeaders(request);
}
private static StringBuffer initRequestUrl(String scheme, String host, int port, String path) {
StringBuffer sb = new StringBuffer();
sb.append(scheme).append("://").append(host);
sb.append(port == -1 ? "" : ":" + port);
sb.append(path);
return sb;
}
/**
* Copy the headers excluding any {@link #FORWARDED_HEADER_NAMES}.
*/
private static Map<String, List<String>> initHeaders(HttpServletRequest request) {
Map<String, List<String>> headers = new LinkedHashMap<String, List<String>>();
Enumeration<String> headerNames = request.getHeaderNames();
while (headerNames.hasMoreElements()) {
String name = headerNames.nextElement();
this.headers.put(name, Collections.list(request.getHeaders(name)));
headers.put(name, Collections.list(request.getHeaders(name)));
}
for (String name : FORWARDED_HEADER_NAMES) {
this.headers.remove(name);
headers.remove(name);
}
return headers;
}
@Override
public String getHeader(String name) {
Map.Entry<String, List<String>> header = getHeaderEntry(name);
if (header == null || header.getValue() == null || header.getValue().isEmpty()) {
return null;
}
return header.getValue().get(0);
}
protected Map.Entry<String, List<String>> getHeaderEntry(String name) {
for (Map.Entry<String, List<String>> entry : this.headers.entrySet()) {
if (entry.getKey().equalsIgnoreCase(name)) {
return entry;
}
}
return null;
}
@Override
public Enumeration<String> getHeaderNames() {
return Collections.enumeration(this.headers.keySet());
}
@Override
public Enumeration<String> getHeaders(String name) {
Map.Entry<String, List<String>> header = getHeaderEntry(name);
if (header == null || header.getValue() == null) {
return Collections.enumeration(Collections.<String>emptyList());
}
return Collections.enumeration(header.getValue());
}
@Override
public String getScheme() {
return this.scheme;
@ -183,10 +172,26 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @@ -183,10 +172,26 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
@Override
public StringBuffer getRequestURL() {
StringBuffer sb = new StringBuffer();
sb.append(this.scheme).append("://").append(this.host).append(this.portInUrl);
sb.append(getRequestURI());
return sb;
return this.requestUrl;
}
// Override header accessors in order to not expose forwarded headers
@Override
public String getHeader(String name) {
List<String> value = this.headers.get(name);
return (CollectionUtils.isEmpty(value) ? null : value.get(0));
}
@Override
public Enumeration<String> getHeaders(String name) {
List<String> value = this.headers.get(name);
return (CollectionUtils.isEmpty(value) ? EMPTY_HEADER_VALUES : Collections.enumeration(value));
}
@Override
public Enumeration<String> getHeaderNames() {
return Collections.enumeration(this.headers.keySet());
}
}

5
spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java

@ -53,7 +53,7 @@ public class ForwardedHeaderFilterTests { @@ -53,7 +53,7 @@ public class ForwardedHeaderFilterTests {
}
@Test
public void xForwardedHeaders() throws Exception {
public void forwardedRequest() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setScheme("http");
request.setServerName("localhost");
@ -65,10 +65,9 @@ public class ForwardedHeaderFilterTests { @@ -65,10 +65,9 @@ public class ForwardedHeaderFilterTests {
request.addHeader("foo", "bar");
MockFilterChain chain = new MockFilterChain(new HttpServlet() {});
this.filter.doFilter(request, new MockHttpServletResponse(), chain);
HttpServletRequest actual = (HttpServletRequest) chain.getRequest();
assertEquals("https://84.198.58.199/mvc-showcase", actual.getRequestURL().toString());
assertEquals("https", actual.getScheme());
assertEquals("84.198.58.199", actual.getServerName());

Loading…
Cancel
Save