diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java index 06a486a440..ae223a81d7 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java @@ -27,7 +27,7 @@ import java.util.Properties; import java.util.function.BiFunction; import java.util.stream.Collectors; -import jakarta.servlet.Filter; +import jakarta.servlet.DispatcherType; import jakarta.servlet.ServletException; import jakarta.servlet.ServletRequest; import jakarta.servlet.http.HttpServletRequest; @@ -68,25 +68,27 @@ import org.springframework.web.util.pattern.PathPatternParser; * request. * * - *

Note: This is primarily an SPI to allow Spring Security + *

Note that this is primarily an SPI to allow Spring Security * to align its pattern matching with the same pattern matching that would be * used in Spring MVC for a given request, in order to avoid security issues. * Use of this introspector should be avoided for other purposes because it * incurs the overhead of resolving the handler for a request. * + *

Alternative security filter solutions that also rely on + * {@link HandlerMappingIntrospector} should consider adding an additional + * {@link jakarta.servlet.Filter} that invokes + * {@link #setCache(HttpServletRequest)} and {@link #resetCache(ServletRequest, CachedResult)} + * before and after delegating to the rest of the chain. Such a Filter should + * process all dispatcher types and should be ordered ahead of security filters. + * * @author Rossen Stoyanchev * @since 4.3.1 */ public class HandlerMappingIntrospector implements CorsConfigurationSource, ApplicationContextAware, InitializingBean { - static final String MAPPING_ATTRIBUTE = - HandlerMappingIntrospector.class.getName() + ".HandlerMapping"; - - static final String CORS_CONFIG_ATTRIBUTE = - HandlerMappingIntrospector.class.getName() + ".CorsConfig"; - - private static final CorsConfiguration NO_CORS_CONFIG = new CorsConfiguration(); + private static final String CACHED_RESULT_ATTRIBUTE = + HandlerMappingIntrospector.class.getName() + ".CachedResult"; @Nullable @@ -166,55 +168,43 @@ public class HandlerMappingIntrospector /** - * Return Filter that performs lookups, caches the results in request attributes, - * and clears the attributes after the filter chain returns. + * Perform a lookup and save the {@link CachedResult} as a request attribute. + * This method can be invoked from a filter before subsequent calls to + * {@link #getMatchableHandlerMapping(HttpServletRequest)} and + * {@link #getCorsConfiguration(HttpServletRequest)} to avoid repeated lookups. + * @param request the current request + * @return the previous {@link CachedResult}, if there is one from a parent dispatch + * @throws ServletException thrown the lookup fails for any reason * @since 6.0.14 */ - public Filter createCacheFilter() { - return (request, response, chain) -> { - MatchableHandlerMapping previousMapping = getCachedMapping(request); - CorsConfiguration previousCorsConfig = getCachedCorsConfiguration(request); + @Nullable + public CachedResult setCache(HttpServletRequest request) throws ServletException { + CachedResult previous = getAttribute(request); + if (previous == null || !previous.matches(request)) { try { - HttpServletRequest wrappedRequest = new AttributesPreservingRequest((HttpServletRequest) request); - doWithHandlerMapping(wrappedRequest, false, (mapping, executionChain) -> { - MatchableHandlerMapping matchableMapping = createMatchableHandlerMapping(mapping, wrappedRequest); - CorsConfiguration corsConfig = getCorsConfiguration(wrappedRequest, executionChain); - setCache(request, matchableMapping, corsConfig); - return null; + HttpServletRequest wrapped = new AttributesPreservingRequest(request); + CachedResult cachedResult = doWithHandlerMapping(wrapped, false, (mapping, executionChain) -> { + MatchableHandlerMapping matchableMapping = createMatchableHandlerMapping(mapping, wrapped); + CorsConfiguration corsConfig = getCorsConfiguration(wrapped, executionChain); + return new CachedResult(request, matchableMapping, corsConfig); }); - chain.doFilter(request, response); + request.setAttribute(CACHED_RESULT_ATTRIBUTE, + cachedResult != null ? cachedResult : new CachedResult(request, null, null)); } - catch (Exception ex) { + catch (Throwable ex) { throw new ServletException("HandlerMapping introspection failed", ex); } - finally { - setCache(request, previousMapping, previousCorsConfig); - } - }; - } - - @Nullable - private static MatchableHandlerMapping getCachedMapping(ServletRequest request) { - return (MatchableHandlerMapping) request.getAttribute(MAPPING_ATTRIBUTE); - } - - @Nullable - private static CorsConfiguration getCachedCorsConfiguration(ServletRequest request) { - return (CorsConfiguration) request.getAttribute(CORS_CONFIG_ATTRIBUTE); + } + return previous; } - private static void setCache( - ServletRequest request, @Nullable MatchableHandlerMapping mapping, - @Nullable CorsConfiguration corsConfig) { - - if (mapping != null) { - request.setAttribute(MAPPING_ATTRIBUTE, mapping); - request.setAttribute(CORS_CONFIG_ATTRIBUTE, (corsConfig != null ? corsConfig : NO_CORS_CONFIG)); - } - else { - request.removeAttribute(MAPPING_ATTRIBUTE); - request.removeAttribute(CORS_CONFIG_ATTRIBUTE); - } + /** + * Restore a previous {@link CachedResult}. This method can be invoked from + * a filter after delegating to the rest of the chain. + * @since 6.0.14 + */ + public void resetCache(ServletRequest request, @Nullable CachedResult cachedResult) { + request.setAttribute(CACHED_RESULT_ATTRIBUTE, cachedResult); } /** @@ -228,9 +218,9 @@ public class HandlerMappingIntrospector */ @Nullable public MatchableHandlerMapping getMatchableHandlerMapping(HttpServletRequest request) throws Exception { - MatchableHandlerMapping cachedMapping = getCachedMapping(request); - if (cachedMapping != null) { - return cachedMapping; + CachedResult cachedResult = getCachedResultFor(request); + if (cachedResult != null) { + return cachedResult.getHandlerMapping(); } HttpServletRequest requestToUse = new AttributesPreservingRequest(request); return doWithHandlerMapping(requestToUse, false, @@ -255,9 +245,9 @@ public class HandlerMappingIntrospector @Override @Nullable public CorsConfiguration getCorsConfiguration(HttpServletRequest request) { - CorsConfiguration cachedCorsConfiguration = getCachedCorsConfiguration(request); - if (cachedCorsConfiguration != null) { - return (cachedCorsConfiguration != NO_CORS_CONFIG ? cachedCorsConfiguration : null); + CachedResult cachedResult = getCachedResultFor(request); + if (cachedResult != null) { + return cachedResult.getCorsConfig(); } try { boolean ignoreException = true; @@ -322,6 +312,68 @@ public class HandlerMappingIntrospector return null; } + /** + * Return a {@link CachedResult} that matches the given request. + */ + @Nullable + private CachedResult getCachedResultFor(HttpServletRequest request) { + CachedResult result = getAttribute(request); + return (result != null && result.matches(request) ? result : null); + } + + @Nullable + private static CachedResult getAttribute(HttpServletRequest request) { + return (CachedResult) request.getAttribute(CACHED_RESULT_ATTRIBUTE); + } + + + /** + * Container for a {@link MatchableHandlerMapping} and {@link CorsConfiguration} + * for a given request identified by dispatcher type and requestURI. + * @since 6.0.14 + */ + public final static class CachedResult { + + private final DispatcherType dispatcherType; + + private final String requestURI; + + @Nullable + private final MatchableHandlerMapping handlerMapping; + + @Nullable + private final CorsConfiguration corsConfig; + + private CachedResult(HttpServletRequest request, + @Nullable MatchableHandlerMapping mapping, @Nullable CorsConfiguration config) { + + this.dispatcherType = request.getDispatcherType(); + this.requestURI = request.getRequestURI(); + this.handlerMapping = mapping; + this.corsConfig = config; + } + + public boolean matches(HttpServletRequest request) { + return (this.dispatcherType.equals(request.getDispatcherType()) && + this.requestURI.matches(request.getRequestURI())); + } + + @Nullable + public MatchableHandlerMapping getHandlerMapping() { + return this.handlerMapping; + } + + @Nullable + public CorsConfiguration getCorsConfig() { + return this.corsConfig; + } + + @Override + public String toString() { + return "CacheValue " + this.dispatcherType + " '" + this.requestURI + "'"; + } + } + /** * Request wrapper that buffers request attributes in order protect the diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java index e1080b23da..94ff2671a6 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java @@ -23,7 +23,10 @@ import java.util.Collections; import java.util.List; import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; @@ -42,12 +45,14 @@ import org.springframework.web.context.support.AnnotationConfigWebApplicationCon import org.springframework.web.context.support.GenericWebApplicationContext; import org.springframework.web.context.support.StaticWebApplicationContext; import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.cors.CorsConfigurationSource; import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.RouterFunctions; import org.springframework.web.servlet.function.ServerResponse; import org.springframework.web.servlet.function.support.RouterFunctionMapping; +import org.springframework.web.servlet.handler.HandlerMappingIntrospector.CachedResult; import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping; import org.springframework.web.testfixture.servlet.MockFilterChain; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; @@ -202,68 +207,67 @@ public class HandlerMappingIntrospectorTests { @Test void cacheFilter() throws Exception { - testCacheFilter(new MockHttpServletRequest()); - } + CorsConfiguration corsConfig = new CorsConfiguration(); + TestMatchableHandlerMapping mapping = new TestMatchableHandlerMapping(); + mapping.registerHandler("/test", new TestHandler(corsConfig)); - @Test - void cacheFilterRestoresPreviousValues() throws Exception { - TestMatchableHandlerMapping previousMapping = new TestMatchableHandlerMapping(); - CorsConfiguration previousCorsConfig = new CorsConfiguration(); + HandlerMappingIntrospector introspector = initIntrospector(mapping); - MockHttpServletRequest request = new MockHttpServletRequest(); - request.setAttribute(HandlerMappingIntrospector.MAPPING_ATTRIBUTE, previousMapping); - request.setAttribute(HandlerMappingIntrospector.CORS_CONFIG_ATTRIBUTE, previousCorsConfig); + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/test"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + MockFilterChain filterChain = new MockFilterChain( + new TestServlet(), new CacheResultFilter(introspector), new AuthFilter(introspector, corsConfig)); - testCacheFilter(request); + filterChain.doFilter(request, response); - assertThat(previousMapping.getInvocationCount()).isEqualTo(0); - assertThat(request.getAttribute(HandlerMappingIntrospector.MAPPING_ATTRIBUTE)).isSameAs(previousMapping); - assertThat(request.getAttribute(HandlerMappingIntrospector.CORS_CONFIG_ATTRIBUTE)).isSameAs(previousCorsConfig); + assertThat(response.getContentAsString()).isEqualTo("Success"); + assertThat(mapping.getInvocationCount()).isEqualTo(1); + assertThat(mapping.getMatchCount()).isEqualTo(1); } - private void testCacheFilter(MockHttpServletRequest request) throws IOException, ServletException { - TestMatchableHandlerMapping mapping = new TestMatchableHandlerMapping(); - StaticWebApplicationContext context = new StaticWebApplicationContext(); - context.registerBean(TestMatchableHandlerMapping.class, () -> mapping); - context.refresh(); + @Test + void cacheFilterWithNestedDispatch() throws Exception { + CorsConfiguration corsConfig1 = new CorsConfiguration(); + CorsConfiguration corsConfig2 = new CorsConfiguration(); - HandlerMappingIntrospector introspector = initIntrospector(context); - MockHttpServletResponse response = new MockHttpServletResponse(); + TestMatchableHandlerMapping mapping1 = new TestMatchableHandlerMapping(); + TestMatchableHandlerMapping mapping2 = new TestMatchableHandlerMapping(); - Filter filter = (req, res, chain) -> { - try { - for (int i = 0; i < 10; i++) { - introspector.getMatchableHandlerMapping((HttpServletRequest) req); - introspector.getCorsConfiguration((HttpServletRequest) req); - } - } - catch (Exception ex) { - throw new IllegalStateException(ex); - } - chain.doFilter(req, res); - }; + mapping1.registerHandler("/1", new TestHandler(corsConfig1)); + mapping2.registerHandler("/2", new TestHandler(corsConfig2)); - HttpServlet servlet = new HttpServlet() { + HandlerMappingIntrospector introspector = initIntrospector(mapping1, mapping2); - @Override - protected void service(HttpServletRequest req, HttpServletResponse res) { - try { - res.getWriter().print("Success"); - } - catch (Exception ex) { - throw new IllegalStateException(ex); - } - } - }; + MockFilterChain filterChain = new MockFilterChain( + new TestServlet(), + new CacheResultFilter(introspector), + new AuthFilter(introspector, corsConfig1), + (req, res, chain) -> chain.doFilter(new MockHttpServletRequest("GET", "/2"), res), + new CacheResultFilter(introspector), + new AuthFilter(introspector, corsConfig2)); - new MockFilterChain(servlet, introspector.createCacheFilter(), filter) - .doFilter(request, response); + MockHttpServletResponse response = new MockHttpServletResponse(); + filterChain.doFilter(new MockHttpServletRequest("GET", "/1"), response); assertThat(response.getContentAsString()).isEqualTo("Success"); - assertThat(mapping.getInvocationCount()).isEqualTo(1); + assertThat(mapping1.getInvocationCount()).isEqualTo(2); + assertThat(mapping2.getInvocationCount()).isEqualTo(1); + assertThat(mapping1.getMatchCount()).isEqualTo(1); + assertThat(mapping2.getMatchCount()).isEqualTo(1); } - private HandlerMappingIntrospector initIntrospector(WebApplicationContext context) { + private HandlerMappingIntrospector initIntrospector(TestMatchableHandlerMapping... mappings) { + StaticWebApplicationContext context = new StaticWebApplicationContext(); + int index = 0; + for (TestMatchableHandlerMapping mapping : mappings) { + context.registerBean("mapping" + index++, TestMatchableHandlerMapping.class, () -> mapping); + } + context.refresh(); + return initIntrospector(context); + } + + private static HandlerMappingIntrospector initIntrospector(WebApplicationContext context) { HandlerMappingIntrospector introspector = new HandlerMappingIntrospector(); introspector.setApplicationContext(context); introspector.afterPropertiesSet(); @@ -327,23 +331,111 @@ public class HandlerMappingIntrospectorTests { } - private static class TestMatchableHandlerMapping implements MatchableHandlerMapping { + private static class TestMatchableHandlerMapping extends SimpleUrlHandlerMapping { private int invocationCount; + private int matchCount; + public int getInvocationCount() { return this.invocationCount; } + public int getMatchCount() { + return this.matchCount; + } + @Override - public HandlerExecutionChain getHandler(HttpServletRequest request) { + protected Object getHandlerInternal(HttpServletRequest request) throws Exception { this.invocationCount++; - return new HandlerExecutionChain(new Object()); + Object handler = super.getHandlerInternal(request); + if (handler != null) { + this.matchCount++; + } + return handler; + } + } + + + private static class TestHandler implements CorsConfigurationSource { + + private final CorsConfiguration corsConfig; + + private TestHandler(CorsConfiguration corsConfig) { + this.corsConfig = corsConfig; } @Override - public RequestMatchResult match(HttpServletRequest request, String pattern) { - throw new UnsupportedOperationException(); + public CorsConfiguration getCorsConfiguration(HttpServletRequest request) { + return this.corsConfig; + } + } + + + private static class CacheResultFilter implements Filter { + + private final HandlerMappingIntrospector introspector; + + private CacheResultFilter(HandlerMappingIntrospector introspector) { + this.introspector = introspector; + } + + @Override + public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) + throws ServletException { + + CachedResult previousValue = this.introspector.setCache((HttpServletRequest) req); + try { + chain.doFilter(req, res); + } + catch (Exception ex) { + throw new ServletException("HandlerMapping introspection failed", ex); + } + finally { + this.introspector.resetCache(req, previousValue); + } + } + } + + + private static class AuthFilter implements Filter { + + private final HandlerMappingIntrospector introspector; + + private final CorsConfiguration corsConfig; + + private AuthFilter(HandlerMappingIntrospector introspector, CorsConfiguration corsConfig) { + this.introspector = introspector; + this.corsConfig = corsConfig; + } + + @Override + public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException { + try { + for (int i = 0; i < 10; i++) { + HttpServletRequest httpRequest = (HttpServletRequest) req; + assertThat(introspector.getMatchableHandlerMapping(httpRequest)).isNotNull(); + assertThat(introspector.getCorsConfiguration(httpRequest)).isSameAs(corsConfig); + } + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + chain.doFilter(req, res); + } + } + + + private static class TestServlet extends HttpServlet { + + @Override + protected void service(HttpServletRequest req, HttpServletResponse res) { + try { + res.getWriter().print("Success"); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } } }