@ -31,6 +31,7 @@ import javax.servlet.http.HttpServletResponse;
@@ -31,6 +31,7 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpRequest ;
import org.springframework.http.server.ServletServerHttpRequest ;
import org.springframework.util.CollectionUtils ;
import org.springframework.web.util.UriComponents ;
import org.springframework.web.util.UriComponentsBuilder ;
@ -92,6 +93,10 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
@@ -92,6 +93,10 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
private static class ForwardedHeaderRequestWrapper extends HttpServletRequestWrapper {
public static final Enumeration < String > EMPTY_HEADER_VALUES =
Collections . enumeration ( Collections . < String > emptyList ( ) ) ;
private final String scheme ;
private final boolean secure ;
@ -100,9 +105,9 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
@@ -100,9 +105,9 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
private final int port ;
private final String po rtIn Url;
private final StringBuffer reques tUrl ;
private final Map < String , List < String > > headers = new LinkedHashMap < String , List < String > > ( ) ;
private final Map < String , List < String > > headers ;
public ForwardedHeaderRequestWrapper ( HttpServletRequest request ) {
@ -116,51 +121,35 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
@@ -116,51 +121,35 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
this . secure = "https" . equals ( scheme ) ;
this . host = uriComponents . getHost ( ) ;
this . port = ( port = = - 1 ? ( this . secure ? 443 : 80 ) : port ) ;
this . portInUrl = ( port = = - 1 ? "" : ":" + port ) ;
this . requestUrl = initRequestUrl ( this . scheme , this . host , port , request . getRequestURI ( ) ) ;
this . headers = initHeaders ( request ) ;
}
private static StringBuffer initRequestUrl ( String scheme , String host , int port , String path ) {
StringBuffer sb = new StringBuffer ( ) ;
sb . append ( scheme ) . append ( "://" ) . append ( host ) ;
sb . append ( port = = - 1 ? "" : ":" + port ) ;
sb . append ( path ) ;
return sb ;
}
/ * *
* Copy the headers excluding any { @link # FORWARDED_HEADER_NAMES } .
* /
private static Map < String , List < String > > initHeaders ( HttpServletRequest request ) {
Map < String , List < String > > headers = new LinkedHashMap < String , List < String > > ( ) ;
Enumeration < String > headerNames = request . getHeaderNames ( ) ;
while ( headerNames . hasMoreElements ( ) ) {
String name = headerNames . nextElement ( ) ;
this . headers . put ( name , Collections . list ( request . getHeaders ( name ) ) ) ;
headers . put ( name , Collections . list ( request . getHeaders ( name ) ) ) ;
}
for ( String name : FORWARDED_HEADER_NAMES ) {
this . headers . remove ( name ) ;
headers . remove ( name ) ;
}
return headers ;
}
@Override
public String getHeader ( String name ) {
Map . Entry < String , List < String > > header = getHeaderEntry ( name ) ;
if ( header = = null | | header . getValue ( ) = = null | | header . getValue ( ) . isEmpty ( ) ) {
return null ;
}
return header . getValue ( ) . get ( 0 ) ;
}
protected Map . Entry < String , List < String > > getHeaderEntry ( String name ) {
for ( Map . Entry < String , List < String > > entry : this . headers . entrySet ( ) ) {
if ( entry . getKey ( ) . equalsIgnoreCase ( name ) ) {
return entry ;
}
}
return null ;
}
@Override
public Enumeration < String > getHeaderNames ( ) {
return Collections . enumeration ( this . headers . keySet ( ) ) ;
}
@Override
public Enumeration < String > getHeaders ( String name ) {
Map . Entry < String , List < String > > header = getHeaderEntry ( name ) ;
if ( header = = null | | header . getValue ( ) = = null ) {
return Collections . enumeration ( Collections . < String > emptyList ( ) ) ;
}
return Collections . enumeration ( header . getValue ( ) ) ;
}
@Override
public String getScheme ( ) {
return this . scheme ;
@ -183,10 +172,26 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
@@ -183,10 +172,26 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
@Override
public StringBuffer getRequestURL ( ) {
StringBuffer sb = new StringBuffer ( ) ;
sb . append ( this . scheme ) . append ( "://" ) . append ( this . host ) . append ( this . portInUrl ) ;
sb . append ( getRequestURI ( ) ) ;
return sb ;
return this . requestUrl ;
}
// Override header accessors in order to not expose forwarded headers
@Override
public String getHeader ( String name ) {
List < String > value = this . headers . get ( name ) ;
return ( CollectionUtils . isEmpty ( value ) ? null : value . get ( 0 ) ) ;
}
@Override
public Enumeration < String > getHeaders ( String name ) {
List < String > value = this . headers . get ( name ) ;
return ( CollectionUtils . isEmpty ( value ) ? EMPTY_HEADER_VALUES : Collections . enumeration ( value ) ) ;
}
@Override
public Enumeration < String > getHeaderNames ( ) {
return Collections . enumeration ( this . headers . keySet ( ) ) ;
}
}