Browse Source

Header predicate option in AbstractRequestLoggingFilter

Closes gh-22244
pull/26676/head
Rossen Stoyanchev 6 years ago
parent
commit
57a67a3c06
  1. 40
      spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java
  2. 74
      spring-web/src/test/java/org/springframework/web/filter/RequestLoggingFilterTests.java

40
spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,12 +18,15 @@ package org.springframework.web.filter; @@ -18,12 +18,15 @@ package org.springframework.web.filter;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.Enumeration;
import java.util.function.Predicate;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
@ -95,6 +98,9 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter @@ -95,6 +98,9 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter
private boolean includePayload = false;
@Nullable
private Predicate<String> headerPredicate;
private int maxPayloadLength = DEFAULT_MAX_PAYLOAD_LENGTH;
private String beforeMessagePrefix = DEFAULT_BEFORE_MESSAGE_PREFIX;
@ -176,6 +182,26 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter @@ -176,6 +182,26 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter
return this.includePayload;
}
/**
* Configure a predicate for selecting which headers should be logged if
* {@link #setIncludeHeaders(boolean)} is set to {@code true}.
* <p>By default this is not set in which case all headers are logged.
* @param headerPredicate the predicate to use
* @since 5.2
*/
public void setHeaderPredicate(@Nullable Predicate<String> headerPredicate) {
this.headerPredicate = headerPredicate;
}
/**
* The configured {@link #setHeaderPredicate(Predicate) headerPredicate}.
* @since 5.2
*/
@Nullable
public Predicate<String> getHeaderPredicate() {
return this.headerPredicate;
}
/**
* Set the maximum length of the payload body to be included in the log message.
* Default is 50 characters.
@ -320,7 +346,17 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter @@ -320,7 +346,17 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter
}
if (isIncludeHeaders()) {
msg.append(";headers=").append(new ServletServerHttpRequest(request).getHeaders());
HttpHeaders headers = new ServletServerHttpRequest(request).getHeaders();
if (getHeaderPredicate() != null) {
Enumeration<String> names = request.getHeaderNames();
while (names.hasMoreElements()) {
String header = names.nextElement();
if (!getHeaderPredicate().test(header)) {
headers.set(header, "masked");
}
}
}
msg.append(";headers=").append(headers);
}
if (isIncludePayload()) {

74
spring-web/src/test/java/org/springframework/web/filter/RequestLoggingFilterTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@ @@ -17,6 +17,7 @@
package org.springframework.web.filter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
@ -100,6 +101,27 @@ public class RequestLoggingFilterTests { @@ -100,6 +101,27 @@ public class RequestLoggingFilterTests {
assertTrue(filter.afterRequestMessage.contains("[uri=/hotels]"));
}
@Test
public void headers() throws Exception {
final MockHttpServletRequest request = new MockHttpServletRequest("POST", "/hotels");
request.setContentType("application/json");
request.addHeader("token", "123");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = new NoOpFilterChain();
filter.setIncludeHeaders(true);
filter.setHeaderPredicate(name -> !name.equalsIgnoreCase("token"));
filter.doFilter(request, response, filterChain);
assertNotNull(filter.beforeRequestMessage);
assertEquals("Before request [uri=/hotels;headers=[Content-Type:\"application/json\", token:\"masked\"]]",
filter.beforeRequestMessage);
assertNotNull(filter.afterRequestMessage);
assertEquals("After request [uri=/hotels;headers=[Content-Type:\"application/json\", token:\"masked\"]]",
filter.afterRequestMessage);
}
@Test
public void payloadInputStream() throws Exception {
filter.setIncludePayload(true);
@ -107,17 +129,13 @@ public class RequestLoggingFilterTests { @@ -107,17 +129,13 @@ public class RequestLoggingFilterTests {
final MockHttpServletRequest request = new MockHttpServletRequest("POST", "/hotels");
MockHttpServletResponse response = new MockHttpServletResponse();
final byte[] requestBody = "Hello World".getBytes("UTF-8");
final byte[] requestBody = "Hello World".getBytes(StandardCharsets.UTF_8);
request.setContent(requestBody);
FilterChain filterChain = new FilterChain() {
@Override
public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse)
throws IOException, ServletException {
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
byte[] buf = FileCopyUtils.copyToByteArray(filterRequest.getInputStream());
assertArrayEquals(requestBody, buf);
}
FilterChain filterChain = (filterRequest, filterResponse) -> {
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
byte[] buf = FileCopyUtils.copyToByteArray(filterRequest.getInputStream());
assertArrayEquals(requestBody, buf);
};
filter.doFilter(request, response, filterChain);
@ -134,16 +152,12 @@ public class RequestLoggingFilterTests { @@ -134,16 +152,12 @@ public class RequestLoggingFilterTests {
MockHttpServletResponse response = new MockHttpServletResponse();
final String requestBody = "Hello World";
request.setContent(requestBody.getBytes("UTF-8"));
FilterChain filterChain = new FilterChain() {
@Override
public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse)
throws IOException, ServletException {
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
String buf = FileCopyUtils.copyToString(filterRequest.getReader());
assertEquals(requestBody, buf);
}
request.setContent(requestBody.getBytes(StandardCharsets.UTF_8));
FilterChain filterChain = (filterRequest, filterResponse) -> {
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
String buf = FileCopyUtils.copyToString(filterRequest.getReader());
assertEquals(requestBody, buf);
};
filter.doFilter(request, response, filterChain);
@ -160,20 +174,16 @@ public class RequestLoggingFilterTests { @@ -160,20 +174,16 @@ public class RequestLoggingFilterTests {
final MockHttpServletRequest request = new MockHttpServletRequest("POST", "/hotels");
MockHttpServletResponse response = new MockHttpServletResponse();
final byte[] requestBody = "Hello World".getBytes("UTF-8");
final byte[] requestBody = "Hello World".getBytes(StandardCharsets.UTF_8);
request.setContent(requestBody);
FilterChain filterChain = new FilterChain() {
@Override
public void doFilter(ServletRequest filterRequest, ServletResponse filterResponse)
throws IOException, ServletException {
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
byte[] buf = FileCopyUtils.copyToByteArray(filterRequest.getInputStream());
assertArrayEquals(requestBody, buf);
ContentCachingRequestWrapper wrapper =
WebUtils.getNativeRequest(filterRequest, ContentCachingRequestWrapper.class);
assertArrayEquals("Hel".getBytes("UTF-8"), wrapper.getContentAsByteArray());
}
FilterChain filterChain = (filterRequest, filterResponse) -> {
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
byte[] buf = FileCopyUtils.copyToByteArray(filterRequest.getInputStream());
assertArrayEquals(requestBody, buf);
ContentCachingRequestWrapper wrapper =
WebUtils.getNativeRequest(filterRequest, ContentCachingRequestWrapper.class);
assertArrayEquals("Hel".getBytes(StandardCharsets.UTF_8), wrapper.getContentAsByteArray());
};
filter.doFilter(request, response, filterChain);

Loading…
Cancel
Save