diff --git a/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java b/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java index c498645258..b6aa4bc2d4 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,12 +18,15 @@ package org.springframework.web.filter; import java.io.IOException; import java.io.UnsupportedEncodingException; +import java.util.Enumeration; +import java.util.function.Predicate; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; +import org.springframework.http.HttpHeaders; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -95,6 +98,9 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter private boolean includePayload = false; + @Nullable + private Predicate headerPredicate; + private int maxPayloadLength = DEFAULT_MAX_PAYLOAD_LENGTH; private String beforeMessagePrefix = DEFAULT_BEFORE_MESSAGE_PREFIX; @@ -176,6 +182,26 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter return this.includePayload; } + /** + * Configure a predicate for selecting which headers should be logged if + * {@link #setIncludeHeaders(boolean)} is set to {@code true}. + *

By default this is not set in which case all headers are logged. + * @param headerPredicate the predicate to use + * @since 5.2 + */ + public void setHeaderPredicate(@Nullable Predicate headerPredicate) { + this.headerPredicate = headerPredicate; + } + + /** + * The configured {@link #setHeaderPredicate(Predicate) headerPredicate}. + * @since 5.2 + */ + @Nullable + public Predicate getHeaderPredicate() { + return this.headerPredicate; + } + /** * Set the maximum length of the payload body to be included in the log message. * Default is 50 characters. @@ -320,7 +346,17 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter } if (isIncludeHeaders()) { - msg.append(";headers=").append(new ServletServerHttpRequest(request).getHeaders()); + HttpHeaders headers = new ServletServerHttpRequest(request).getHeaders(); + if (getHeaderPredicate() != null) { + Enumeration names = request.getHeaderNames(); + while (names.hasMoreElements()) { + String header = names.nextElement(); + if (!getHeaderPredicate().test(header)) { + headers.set(header, "masked"); + } + } + } + msg.append(";headers=").append(headers); } if (isIncludePayload()) { diff --git a/spring-web/src/test/java/org/springframework/web/filter/RequestLoggingFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/RequestLoggingFilterTests.java index 57777f2779..87a8b302b1 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/RequestLoggingFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/RequestLoggingFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.web.filter; import java.io.IOException; +import java.nio.charset.StandardCharsets; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.ServletRequest; @@ -100,6 +101,27 @@ public class RequestLoggingFilterTests { assertTrue(filter.afterRequestMessage.contains("[uri=/hotels]")); } + @Test + public void headers() throws Exception { + final MockHttpServletRequest request = new MockHttpServletRequest("POST", "/hotels"); + request.setContentType("application/json"); + request.addHeader("token", "123"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = new NoOpFilterChain(); + filter.setIncludeHeaders(true); + filter.setHeaderPredicate(name -> !name.equalsIgnoreCase("token")); + filter.doFilter(request, response, filterChain); + + assertNotNull(filter.beforeRequestMessage); + assertEquals("Before request [uri=/hotels;headers=[Content-Type:\"application/json\", token:\"masked\"]]", + filter.beforeRequestMessage); + + assertNotNull(filter.afterRequestMessage); + assertEquals("After request [uri=/hotels;headers=[Content-Type:\"application/json\", token:\"masked\"]]", + filter.afterRequestMessage); + } + @Test public void payloadInputStream() throws Exception { filter.setIncludePayload(true); @@ -107,17 +129,13 @@ public class RequestLoggingFilterTests { final MockHttpServletRequest request = new MockHttpServletRequest("POST", "/hotels"); MockHttpServletResponse response = new MockHttpServletResponse(); - final byte[] requestBody = "Hello World".getBytes("UTF-8"); + final byte[] requestBody = "Hello World".getBytes(StandardCharsets.UTF_8); request.setContent(requestBody); - FilterChain filterChain = new FilterChain() { - @Override - public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse) - throws IOException, ServletException { - ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); - byte[] buf = FileCopyUtils.copyToByteArray(filterRequest.getInputStream()); - assertArrayEquals(requestBody, buf); - } + FilterChain filterChain = (filterRequest, filterResponse) -> { + ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + byte[] buf = FileCopyUtils.copyToByteArray(filterRequest.getInputStream()); + assertArrayEquals(requestBody, buf); }; filter.doFilter(request, response, filterChain); @@ -134,16 +152,12 @@ public class RequestLoggingFilterTests { MockHttpServletResponse response = new MockHttpServletResponse(); final String requestBody = "Hello World"; - request.setContent(requestBody.getBytes("UTF-8")); - - FilterChain filterChain = new FilterChain() { - @Override - public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse) - throws IOException, ServletException { - ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); - String buf = FileCopyUtils.copyToString(filterRequest.getReader()); - assertEquals(requestBody, buf); - } + request.setContent(requestBody.getBytes(StandardCharsets.UTF_8)); + + FilterChain filterChain = (filterRequest, filterResponse) -> { + ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + String buf = FileCopyUtils.copyToString(filterRequest.getReader()); + assertEquals(requestBody, buf); }; filter.doFilter(request, response, filterChain); @@ -160,20 +174,16 @@ public class RequestLoggingFilterTests { final MockHttpServletRequest request = new MockHttpServletRequest("POST", "/hotels"); MockHttpServletResponse response = new MockHttpServletResponse(); - final byte[] requestBody = "Hello World".getBytes("UTF-8"); + final byte[] requestBody = "Hello World".getBytes(StandardCharsets.UTF_8); request.setContent(requestBody); - FilterChain filterChain = new FilterChain() { - @Override - public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse) - throws IOException, ServletException { - ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); - byte[] buf = FileCopyUtils.copyToByteArray(filterRequest.getInputStream()); - assertArrayEquals(requestBody, buf); - ContentCachingRequestWrapper wrapper = - WebUtils.getNativeRequest(filterRequest, ContentCachingRequestWrapper.class); - assertArrayEquals("Hel".getBytes("UTF-8"), wrapper.getContentAsByteArray()); - } + FilterChain filterChain = (filterRequest, filterResponse) -> { + ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + byte[] buf = FileCopyUtils.copyToByteArray(filterRequest.getInputStream()); + assertArrayEquals(requestBody, buf); + ContentCachingRequestWrapper wrapper = + WebUtils.getNativeRequest(filterRequest, ContentCachingRequestWrapper.class); + assertArrayEquals("Hel".getBytes(StandardCharsets.UTF_8), wrapper.getContentAsByteArray()); }; filter.doFilter(request, response, filterChain);