diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java index 27db5087a2..9d6bb2eed6 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -20,9 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.util.Collections; import java.util.Enumeration; -import java.util.List; import java.util.Locale; -import java.util.Map; import java.util.Set; import java.util.function.Supplier; @@ -37,7 +35,6 @@ import org.springframework.http.HttpStatus; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.lang.Nullable; -import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.util.StringUtils; import org.springframework.web.util.UriComponents; @@ -169,23 +166,26 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { */ private static class ForwardedHeaderRemovingRequest extends HttpServletRequestWrapper { - private final Map> headers; + private final Set headerNames; public ForwardedHeaderRemovingRequest(HttpServletRequest request) { super(request); - this.headers = initHeaders(request); + + this.headerNames = headerNames(request); } - private static Map> initHeaders(HttpServletRequest request) { - Map> headers = new LinkedCaseInsensitiveMap<>(Locale.ENGLISH); - Enumeration names = request.getHeaderNames(); + private static Set headerNames(HttpServletRequest request) { + final var headerNames = Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(Locale.ENGLISH)); + final var names = request.getHeaderNames(); + while (names.hasMoreElements()) { - String name = names.nextElement(); - if (!FORWARDED_HEADER_NAMES.contains(name)) { - headers.put(name, Collections.list(request.getHeaders(name))); - } + final var name = names.nextElement(); + headerNames.add(name); } - return headers; + + headerNames.removeAll(FORWARDED_HEADER_NAMES); + + return Collections.unmodifiableSet(headerNames); } // Override header accessors to not expose forwarded headers @@ -193,19 +193,25 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @Override @Nullable public String getHeader(String name) { - List value = this.headers.get(name); - return (CollectionUtils.isEmpty(value) ? null : value.get(0)); + if (FORWARDED_HEADER_NAMES.contains(name)) { + return null; + } + + return super.getHeader(name); } @Override public Enumeration getHeaders(String name) { - List value = this.headers.get(name); - return (Collections.enumeration(value != null ? value : Collections.emptySet())); + if (FORWARDED_HEADER_NAMES.contains(name)) { + return Collections.emptyEnumeration(); + } + + return super.getHeaders(name); } @Override public Enumeration getHeaderNames() { - return Collections.enumeration(this.headers.keySet()); + return Collections.enumeration(this.headerNames); } }