Browse Source

Revise HandlerMappingIntrospector caching

Expose methods to set and reset cache to use from a Filter instead
of a method to create such a Filter. Also use cached results only
if they match by dispatcher type and requestURI.

See gh-31588
6.0.x
rstoyanchev 10 months ago
parent
commit
a4e3af5cbe
  1. 162
      spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java
  2. 196
      spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java

162
spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java

@ -27,7 +27,7 @@ import java.util.Properties;
import java.util.function.BiFunction; import java.util.function.BiFunction;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import jakarta.servlet.Filter; import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletException; import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest; import jakarta.servlet.ServletRequest;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
@ -68,25 +68,27 @@ import org.springframework.web.util.pattern.PathPatternParser;
* request. * request.
* </ul> * </ul>
* *
* <p><strong>Note:</strong> This is primarily an SPI to allow Spring Security * <p>Note that this is primarily an SPI to allow Spring Security
* to align its pattern matching with the same pattern matching that would be * to align its pattern matching with the same pattern matching that would be
* used in Spring MVC for a given request, in order to avoid security issues. * used in Spring MVC for a given request, in order to avoid security issues.
* Use of this introspector should be avoided for other purposes because it * Use of this introspector should be avoided for other purposes because it
* incurs the overhead of resolving the handler for a request. * incurs the overhead of resolving the handler for a request.
* *
* <p>Alternative security filter solutions that also rely on
* {@link HandlerMappingIntrospector} should consider adding an additional
* {@link jakarta.servlet.Filter} that invokes
* {@link #setCache(HttpServletRequest)} and {@link #resetCache(ServletRequest, CachedResult)}
* before and after delegating to the rest of the chain. Such a Filter should
* process all dispatcher types and should be ordered ahead of security filters.
*
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @since 4.3.1 * @since 4.3.1
*/ */
public class HandlerMappingIntrospector public class HandlerMappingIntrospector
implements CorsConfigurationSource, ApplicationContextAware, InitializingBean { implements CorsConfigurationSource, ApplicationContextAware, InitializingBean {
static final String MAPPING_ATTRIBUTE = private static final String CACHED_RESULT_ATTRIBUTE =
HandlerMappingIntrospector.class.getName() + ".HandlerMapping"; HandlerMappingIntrospector.class.getName() + ".CachedResult";
static final String CORS_CONFIG_ATTRIBUTE =
HandlerMappingIntrospector.class.getName() + ".CorsConfig";
private static final CorsConfiguration NO_CORS_CONFIG = new CorsConfiguration();
@Nullable @Nullable
@ -166,55 +168,43 @@ public class HandlerMappingIntrospector
/** /**
* Return Filter that performs lookups, caches the results in request attributes, * Perform a lookup and save the {@link CachedResult} as a request attribute.
* and clears the attributes after the filter chain returns. * This method can be invoked from a filter before subsequent calls to
* {@link #getMatchableHandlerMapping(HttpServletRequest)} and
* {@link #getCorsConfiguration(HttpServletRequest)} to avoid repeated lookups.
* @param request the current request
* @return the previous {@link CachedResult}, if there is one from a parent dispatch
* @throws ServletException thrown the lookup fails for any reason
* @since 6.0.14 * @since 6.0.14
*/ */
public Filter createCacheFilter() { @Nullable
return (request, response, chain) -> { public CachedResult setCache(HttpServletRequest request) throws ServletException {
MatchableHandlerMapping previousMapping = getCachedMapping(request); CachedResult previous = getAttribute(request);
CorsConfiguration previousCorsConfig = getCachedCorsConfiguration(request); if (previous == null || !previous.matches(request)) {
try { try {
HttpServletRequest wrappedRequest = new AttributesPreservingRequest((HttpServletRequest) request); HttpServletRequest wrapped = new AttributesPreservingRequest(request);
doWithHandlerMapping(wrappedRequest, false, (mapping, executionChain) -> { CachedResult cachedResult = doWithHandlerMapping(wrapped, false, (mapping, executionChain) -> {
MatchableHandlerMapping matchableMapping = createMatchableHandlerMapping(mapping, wrappedRequest); MatchableHandlerMapping matchableMapping = createMatchableHandlerMapping(mapping, wrapped);
CorsConfiguration corsConfig = getCorsConfiguration(wrappedRequest, executionChain); CorsConfiguration corsConfig = getCorsConfiguration(wrapped, executionChain);
setCache(request, matchableMapping, corsConfig); return new CachedResult(request, matchableMapping, corsConfig);
return null;
}); });
chain.doFilter(request, response); request.setAttribute(CACHED_RESULT_ATTRIBUTE,
cachedResult != null ? cachedResult : new CachedResult(request, null, null));
} }
catch (Exception ex) { catch (Throwable ex) {
throw new ServletException("HandlerMapping introspection failed", ex); throw new ServletException("HandlerMapping introspection failed", ex);
} }
finally { }
setCache(request, previousMapping, previousCorsConfig); return previous;
}
};
}
@Nullable
private static MatchableHandlerMapping getCachedMapping(ServletRequest request) {
return (MatchableHandlerMapping) request.getAttribute(MAPPING_ATTRIBUTE);
}
@Nullable
private static CorsConfiguration getCachedCorsConfiguration(ServletRequest request) {
return (CorsConfiguration) request.getAttribute(CORS_CONFIG_ATTRIBUTE);
} }
private static void setCache( /**
ServletRequest request, @Nullable MatchableHandlerMapping mapping, * Restore a previous {@link CachedResult}. This method can be invoked from
@Nullable CorsConfiguration corsConfig) { * a filter after delegating to the rest of the chain.
* @since 6.0.14
if (mapping != null) { */
request.setAttribute(MAPPING_ATTRIBUTE, mapping); public void resetCache(ServletRequest request, @Nullable CachedResult cachedResult) {
request.setAttribute(CORS_CONFIG_ATTRIBUTE, (corsConfig != null ? corsConfig : NO_CORS_CONFIG)); request.setAttribute(CACHED_RESULT_ATTRIBUTE, cachedResult);
}
else {
request.removeAttribute(MAPPING_ATTRIBUTE);
request.removeAttribute(CORS_CONFIG_ATTRIBUTE);
}
} }
/** /**
@ -228,9 +218,9 @@ public class HandlerMappingIntrospector
*/ */
@Nullable @Nullable
public MatchableHandlerMapping getMatchableHandlerMapping(HttpServletRequest request) throws Exception { public MatchableHandlerMapping getMatchableHandlerMapping(HttpServletRequest request) throws Exception {
MatchableHandlerMapping cachedMapping = getCachedMapping(request); CachedResult cachedResult = getCachedResultFor(request);
if (cachedMapping != null) { if (cachedResult != null) {
return cachedMapping; return cachedResult.getHandlerMapping();
} }
HttpServletRequest requestToUse = new AttributesPreservingRequest(request); HttpServletRequest requestToUse = new AttributesPreservingRequest(request);
return doWithHandlerMapping(requestToUse, false, return doWithHandlerMapping(requestToUse, false,
@ -255,9 +245,9 @@ public class HandlerMappingIntrospector
@Override @Override
@Nullable @Nullable
public CorsConfiguration getCorsConfiguration(HttpServletRequest request) { public CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
CorsConfiguration cachedCorsConfiguration = getCachedCorsConfiguration(request); CachedResult cachedResult = getCachedResultFor(request);
if (cachedCorsConfiguration != null) { if (cachedResult != null) {
return (cachedCorsConfiguration != NO_CORS_CONFIG ? cachedCorsConfiguration : null); return cachedResult.getCorsConfig();
} }
try { try {
boolean ignoreException = true; boolean ignoreException = true;
@ -322,6 +312,68 @@ public class HandlerMappingIntrospector
return null; return null;
} }
/**
* Return a {@link CachedResult} that matches the given request.
*/
@Nullable
private CachedResult getCachedResultFor(HttpServletRequest request) {
CachedResult result = getAttribute(request);
return (result != null && result.matches(request) ? result : null);
}
@Nullable
private static CachedResult getAttribute(HttpServletRequest request) {
return (CachedResult) request.getAttribute(CACHED_RESULT_ATTRIBUTE);
}
/**
* Container for a {@link MatchableHandlerMapping} and {@link CorsConfiguration}
* for a given request identified by dispatcher type and requestURI.
* @since 6.0.14
*/
public final static class CachedResult {
private final DispatcherType dispatcherType;
private final String requestURI;
@Nullable
private final MatchableHandlerMapping handlerMapping;
@Nullable
private final CorsConfiguration corsConfig;
private CachedResult(HttpServletRequest request,
@Nullable MatchableHandlerMapping mapping, @Nullable CorsConfiguration config) {
this.dispatcherType = request.getDispatcherType();
this.requestURI = request.getRequestURI();
this.handlerMapping = mapping;
this.corsConfig = config;
}
public boolean matches(HttpServletRequest request) {
return (this.dispatcherType.equals(request.getDispatcherType()) &&
this.requestURI.matches(request.getRequestURI()));
}
@Nullable
public MatchableHandlerMapping getHandlerMapping() {
return this.handlerMapping;
}
@Nullable
public CorsConfiguration getCorsConfig() {
return this.corsConfig;
}
@Override
public String toString() {
return "CacheValue " + this.dispatcherType + " '" + this.requestURI + "'";
}
}
/** /**
* Request wrapper that buffers request attributes in order protect the * Request wrapper that buffers request attributes in order protect the

196
spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java

@ -23,7 +23,10 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import jakarta.servlet.Filter; import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException; import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
@ -42,12 +45,14 @@ import org.springframework.web.context.support.AnnotationConfigWebApplicationCon
import org.springframework.web.context.support.GenericWebApplicationContext; import org.springframework.web.context.support.GenericWebApplicationContext;
import org.springframework.web.context.support.StaticWebApplicationContext; import org.springframework.web.context.support.StaticWebApplicationContext;
import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.RouterFunctions; import org.springframework.web.servlet.function.RouterFunctions;
import org.springframework.web.servlet.function.ServerResponse; import org.springframework.web.servlet.function.ServerResponse;
import org.springframework.web.servlet.function.support.RouterFunctionMapping; import org.springframework.web.servlet.function.support.RouterFunctionMapping;
import org.springframework.web.servlet.handler.HandlerMappingIntrospector.CachedResult;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping; import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;
import org.springframework.web.testfixture.servlet.MockFilterChain; import org.springframework.web.testfixture.servlet.MockFilterChain;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
@ -202,68 +207,67 @@ public class HandlerMappingIntrospectorTests {
@Test @Test
void cacheFilter() throws Exception { void cacheFilter() throws Exception {
testCacheFilter(new MockHttpServletRequest()); CorsConfiguration corsConfig = new CorsConfiguration();
} TestMatchableHandlerMapping mapping = new TestMatchableHandlerMapping();
mapping.registerHandler("/test", new TestHandler(corsConfig));
@Test HandlerMappingIntrospector introspector = initIntrospector(mapping);
void cacheFilterRestoresPreviousValues() throws Exception {
TestMatchableHandlerMapping previousMapping = new TestMatchableHandlerMapping();
CorsConfiguration previousCorsConfig = new CorsConfiguration();
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest("GET", "/test");
request.setAttribute(HandlerMappingIntrospector.MAPPING_ATTRIBUTE, previousMapping); MockHttpServletResponse response = new MockHttpServletResponse();
request.setAttribute(HandlerMappingIntrospector.CORS_CONFIG_ATTRIBUTE, previousCorsConfig);
MockFilterChain filterChain = new MockFilterChain(
new TestServlet(), new CacheResultFilter(introspector), new AuthFilter(introspector, corsConfig));
testCacheFilter(request); filterChain.doFilter(request, response);
assertThat(previousMapping.getInvocationCount()).isEqualTo(0); assertThat(response.getContentAsString()).isEqualTo("Success");
assertThat(request.getAttribute(HandlerMappingIntrospector.MAPPING_ATTRIBUTE)).isSameAs(previousMapping); assertThat(mapping.getInvocationCount()).isEqualTo(1);
assertThat(request.getAttribute(HandlerMappingIntrospector.CORS_CONFIG_ATTRIBUTE)).isSameAs(previousCorsConfig); assertThat(mapping.getMatchCount()).isEqualTo(1);
} }
private void testCacheFilter(MockHttpServletRequest request) throws IOException, ServletException { @Test
TestMatchableHandlerMapping mapping = new TestMatchableHandlerMapping(); void cacheFilterWithNestedDispatch() throws Exception {
StaticWebApplicationContext context = new StaticWebApplicationContext(); CorsConfiguration corsConfig1 = new CorsConfiguration();
context.registerBean(TestMatchableHandlerMapping.class, () -> mapping); CorsConfiguration corsConfig2 = new CorsConfiguration();
context.refresh();
HandlerMappingIntrospector introspector = initIntrospector(context); TestMatchableHandlerMapping mapping1 = new TestMatchableHandlerMapping();
MockHttpServletResponse response = new MockHttpServletResponse(); TestMatchableHandlerMapping mapping2 = new TestMatchableHandlerMapping();
Filter filter = (req, res, chain) -> { mapping1.registerHandler("/1", new TestHandler(corsConfig1));
try { mapping2.registerHandler("/2", new TestHandler(corsConfig2));
for (int i = 0; i < 10; i++) {
introspector.getMatchableHandlerMapping((HttpServletRequest) req);
introspector.getCorsConfiguration((HttpServletRequest) req);
}
}
catch (Exception ex) {
throw new IllegalStateException(ex);
}
chain.doFilter(req, res);
};
HttpServlet servlet = new HttpServlet() { HandlerMappingIntrospector introspector = initIntrospector(mapping1, mapping2);
@Override MockFilterChain filterChain = new MockFilterChain(
protected void service(HttpServletRequest req, HttpServletResponse res) { new TestServlet(),
try { new CacheResultFilter(introspector),
res.getWriter().print("Success"); new AuthFilter(introspector, corsConfig1),
} (req, res, chain) -> chain.doFilter(new MockHttpServletRequest("GET", "/2"), res),
catch (Exception ex) { new CacheResultFilter(introspector),
throw new IllegalStateException(ex); new AuthFilter(introspector, corsConfig2));
}
}
};
new MockFilterChain(servlet, introspector.createCacheFilter(), filter) MockHttpServletResponse response = new MockHttpServletResponse();
.doFilter(request, response); filterChain.doFilter(new MockHttpServletRequest("GET", "/1"), response);
assertThat(response.getContentAsString()).isEqualTo("Success"); assertThat(response.getContentAsString()).isEqualTo("Success");
assertThat(mapping.getInvocationCount()).isEqualTo(1); assertThat(mapping1.getInvocationCount()).isEqualTo(2);
assertThat(mapping2.getInvocationCount()).isEqualTo(1);
assertThat(mapping1.getMatchCount()).isEqualTo(1);
assertThat(mapping2.getMatchCount()).isEqualTo(1);
} }
private HandlerMappingIntrospector initIntrospector(WebApplicationContext context) { private HandlerMappingIntrospector initIntrospector(TestMatchableHandlerMapping... mappings) {
StaticWebApplicationContext context = new StaticWebApplicationContext();
int index = 0;
for (TestMatchableHandlerMapping mapping : mappings) {
context.registerBean("mapping" + index++, TestMatchableHandlerMapping.class, () -> mapping);
}
context.refresh();
return initIntrospector(context);
}
private static HandlerMappingIntrospector initIntrospector(WebApplicationContext context) {
HandlerMappingIntrospector introspector = new HandlerMappingIntrospector(); HandlerMappingIntrospector introspector = new HandlerMappingIntrospector();
introspector.setApplicationContext(context); introspector.setApplicationContext(context);
introspector.afterPropertiesSet(); introspector.afterPropertiesSet();
@ -327,23 +331,111 @@ public class HandlerMappingIntrospectorTests {
} }
private static class TestMatchableHandlerMapping implements MatchableHandlerMapping { private static class TestMatchableHandlerMapping extends SimpleUrlHandlerMapping {
private int invocationCount; private int invocationCount;
private int matchCount;
public int getInvocationCount() { public int getInvocationCount() {
return this.invocationCount; return this.invocationCount;
} }
public int getMatchCount() {
return this.matchCount;
}
@Override @Override
public HandlerExecutionChain getHandler(HttpServletRequest request) { protected Object getHandlerInternal(HttpServletRequest request) throws Exception {
this.invocationCount++; this.invocationCount++;
return new HandlerExecutionChain(new Object()); Object handler = super.getHandlerInternal(request);
if (handler != null) {
this.matchCount++;
}
return handler;
}
}
private static class TestHandler implements CorsConfigurationSource {
private final CorsConfiguration corsConfig;
private TestHandler(CorsConfiguration corsConfig) {
this.corsConfig = corsConfig;
} }
@Override @Override
public RequestMatchResult match(HttpServletRequest request, String pattern) { public CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
throw new UnsupportedOperationException(); return this.corsConfig;
}
}
private static class CacheResultFilter implements Filter {
private final HandlerMappingIntrospector introspector;
private CacheResultFilter(HandlerMappingIntrospector introspector) {
this.introspector = introspector;
}
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
throws ServletException {
CachedResult previousValue = this.introspector.setCache((HttpServletRequest) req);
try {
chain.doFilter(req, res);
}
catch (Exception ex) {
throw new ServletException("HandlerMapping introspection failed", ex);
}
finally {
this.introspector.resetCache(req, previousValue);
}
}
}
private static class AuthFilter implements Filter {
private final HandlerMappingIntrospector introspector;
private final CorsConfiguration corsConfig;
private AuthFilter(HandlerMappingIntrospector introspector, CorsConfiguration corsConfig) {
this.introspector = introspector;
this.corsConfig = corsConfig;
}
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
try {
for (int i = 0; i < 10; i++) {
HttpServletRequest httpRequest = (HttpServletRequest) req;
assertThat(introspector.getMatchableHandlerMapping(httpRequest)).isNotNull();
assertThat(introspector.getCorsConfiguration(httpRequest)).isSameAs(corsConfig);
}
}
catch (Exception ex) {
throw new IllegalStateException(ex);
}
chain.doFilter(req, res);
}
}
private static class TestServlet extends HttpServlet {
@Override
protected void service(HttpServletRequest req, HttpServletResponse res) {
try {
res.getWriter().print("Success");
}
catch (Exception ex) {
throw new IllegalStateException(ex);
}
} }
} }

Loading…
Cancel
Save