Browse Source

Optimize header removal in ForwardedHeaderFilter

The current implementation suggests that the request's headers are not
expected to change. Hence, it's not necessary to copy them.
Furthermore, it might be costly to do so if there are many headers.
Instead, cache only the request's header names for method getHeaderNames.

Methods getHeader and getHeaders delegate to the respective methods of
request if the header name is not in FORWARDED_HEADER_NAMES. Otherwise,
they return null or an empty Enumeration respectively.

See gh-27466
pull/27735/head
Daniel Le 3 years ago committed by Brian Clozel
parent
commit
6605953eb5
  1. 42
      spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java

42
spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java

@ -20,9 +20,7 @@ import java.io.IOException; @@ -20,9 +20,7 @@ import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
@ -37,7 +35,6 @@ import org.springframework.http.HttpStatus; @@ -37,7 +35,6 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedCaseInsensitiveMap;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponents;
@ -169,23 +166,26 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @@ -169,23 +166,26 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
*/
private static class ForwardedHeaderRemovingRequest extends HttpServletRequestWrapper {
private final Map<String, List<String>> headers;
private final Set<String> headerNames;
public ForwardedHeaderRemovingRequest(HttpServletRequest request) {
super(request);
this.headers = initHeaders(request);
this.headerNames = headerNames(request);
}
private static Map<String, List<String>> initHeaders(HttpServletRequest request) {
Map<String, List<String>> headers = new LinkedCaseInsensitiveMap<>(Locale.ENGLISH);
Enumeration<String> names = request.getHeaderNames();
private static Set<String> headerNames(HttpServletRequest request) {
final var headerNames = Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(Locale.ENGLISH));
final var names = request.getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
if (!FORWARDED_HEADER_NAMES.contains(name)) {
headers.put(name, Collections.list(request.getHeaders(name)));
}
final var name = names.nextElement();
headerNames.add(name);
}
return headers;
headerNames.removeAll(FORWARDED_HEADER_NAMES);
return Collections.unmodifiableSet(headerNames);
}
// Override header accessors to not expose forwarded headers
@ -193,19 +193,25 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @@ -193,19 +193,25 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
@Override
@Nullable
public String getHeader(String name) {
List<String> value = this.headers.get(name);
return (CollectionUtils.isEmpty(value) ? null : value.get(0));
if (FORWARDED_HEADER_NAMES.contains(name)) {
return null;
}
return super.getHeader(name);
}
@Override
public Enumeration<String> getHeaders(String name) {
List<String> value = this.headers.get(name);
return (Collections.enumeration(value != null ? value : Collections.emptySet()));
if (FORWARDED_HEADER_NAMES.contains(name)) {
return Collections.emptyEnumeration();
}
return super.getHeaders(name);
}
@Override
public Enumeration<String> getHeaderNames() {
return Collections.enumeration(this.headers.keySet());
return Collections.enumeration(this.headerNames);
}
}

Loading…
Cancel
Save