diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java index 9eaa30deb0..93cbc3884b 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java @@ -27,6 +27,9 @@ import java.util.Properties; import java.util.function.BiFunction; import java.util.stream.Collectors; +import jakarta.servlet.Filter; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequestWrapper; @@ -77,6 +80,15 @@ import org.springframework.web.util.pattern.PathPatternParser; public class HandlerMappingIntrospector implements CorsConfigurationSource, ApplicationContextAware, InitializingBean { + static final String MAPPING_ATTRIBUTE = + HandlerMappingIntrospector.class.getName() + ".HandlerMapping"; + + static final String CORS_CONFIG_ATTRIBUTE = + HandlerMappingIntrospector.class.getName() + ".CorsConfig"; + + private static final CorsConfiguration NO_CORS_CONFIG = new CorsConfiguration(); + + @Nullable private ApplicationContext applicationContext; @@ -153,6 +165,58 @@ public class HandlerMappingIntrospector } + /** + * Return Filter that performs lookups, caches the results in request attributes, + * and clears the attributes after the filter chain returns. + * @since 6.0.14 + */ + public Filter createCacheFilter() { + return (request, response, chain) -> { + MatchableHandlerMapping previousMapping = getCachedMapping(request); + CorsConfiguration previousCorsConfig = getCachedCorsConfiguration(request); + try { + HttpServletRequest wrappedRequest = new AttributesPreservingRequest((HttpServletRequest) request); + doWithHandlerMapping(wrappedRequest, false, (mapping, executionChain) -> { + MatchableHandlerMapping matchableMapping = createMatchableHandlerMapping(mapping, wrappedRequest); + CorsConfiguration corsConfig = getCorsConfiguration(wrappedRequest, executionChain); + setCache(request, matchableMapping, corsConfig); + return null; + }); + chain.doFilter(request, response); + } + catch (Exception ex) { + throw new ServletException("HandlerMapping introspection failed", ex); + } + finally { + setCache(request, previousMapping, previousCorsConfig); + } + }; + } + + @Nullable + private static MatchableHandlerMapping getCachedMapping(ServletRequest request) { + return (MatchableHandlerMapping) request.getAttribute(MAPPING_ATTRIBUTE); + } + + @Nullable + private static CorsConfiguration getCachedCorsConfiguration(ServletRequest request) { + return (CorsConfiguration) request.getAttribute(CORS_CONFIG_ATTRIBUTE); + } + + private static void setCache( + ServletRequest request, @Nullable MatchableHandlerMapping mapping, + @Nullable CorsConfiguration corsConfig) { + + if (mapping != null) { + request.setAttribute(MAPPING_ATTRIBUTE, mapping); + request.setAttribute(CORS_CONFIG_ATTRIBUTE, (corsConfig != null ? corsConfig : NO_CORS_CONFIG)); + } + else { + request.removeAttribute(MAPPING_ATTRIBUTE); + request.removeAttribute(CORS_CONFIG_ATTRIBUTE); + } + } + /** * Find the {@link HandlerMapping} that would handle the given request and * return a {@link MatchableHandlerMapping} to use for path matching. @@ -164,6 +228,10 @@ public class HandlerMappingIntrospector */ @Nullable public MatchableHandlerMapping getMatchableHandlerMapping(HttpServletRequest request) throws Exception { + MatchableHandlerMapping cachedMapping = getCachedMapping(request); + if (cachedMapping != null) { + return cachedMapping; + } HttpServletRequest requestToUse = new AttributesPreservingRequest(request); return doWithHandlerMapping(requestToUse, false, (mapping, executionChain) -> createMatchableHandlerMapping(mapping, requestToUse)); @@ -187,6 +255,10 @@ public class HandlerMappingIntrospector @Override @Nullable public CorsConfiguration getCorsConfiguration(HttpServletRequest request) { + CorsConfiguration cachedCorsConfiguration = getCachedCorsConfiguration(request); + if (cachedCorsConfiguration != null) { + return (cachedCorsConfiguration != NO_CORS_CONFIG ? cachedCorsConfiguration : null); + } try { boolean ignoreException = true; AttributesPreservingRequest requestToUse = new AttributesPreservingRequest(request); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java index d020753e79..e1080b23da 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java @@ -16,12 +16,17 @@ package org.springframework.web.servlet.handler; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import jakarta.servlet.Filter; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -44,7 +49,9 @@ import org.springframework.web.servlet.function.RouterFunctions; import org.springframework.web.servlet.function.ServerResponse; import org.springframework.web.servlet.function.support.RouterFunctionMapping; import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping; +import org.springframework.web.testfixture.servlet.MockFilterChain; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; +import org.springframework.web.testfixture.servlet.MockHttpServletResponse; import org.springframework.web.util.ServletRequestPathUtils; import org.springframework.web.util.pattern.PathPattern; import org.springframework.web.util.pattern.PathPatternParser; @@ -137,7 +144,7 @@ public class HandlerMappingIntrospectorTests { @Test void getMatchableWhereHandlerMappingDoesNotImplementMatchableInterface() { StaticWebApplicationContext cxt = new StaticWebApplicationContext(); - cxt.registerSingleton("mapping", TestHandlerMapping.class); + cxt.registerBean("mapping", HandlerMapping.class, () -> request -> new HandlerExecutionChain(new Object())); cxt.refresh(); MockHttpServletRequest request = new MockHttpServletRequest(); @@ -193,6 +200,69 @@ public class HandlerMappingIntrospectorTests { assertThat(corsConfig.getAllowedMethods()).isEqualTo(Collections.singletonList("POST")); } + @Test + void cacheFilter() throws Exception { + testCacheFilter(new MockHttpServletRequest()); + } + + @Test + void cacheFilterRestoresPreviousValues() throws Exception { + TestMatchableHandlerMapping previousMapping = new TestMatchableHandlerMapping(); + CorsConfiguration previousCorsConfig = new CorsConfiguration(); + + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setAttribute(HandlerMappingIntrospector.MAPPING_ATTRIBUTE, previousMapping); + request.setAttribute(HandlerMappingIntrospector.CORS_CONFIG_ATTRIBUTE, previousCorsConfig); + + testCacheFilter(request); + + assertThat(previousMapping.getInvocationCount()).isEqualTo(0); + assertThat(request.getAttribute(HandlerMappingIntrospector.MAPPING_ATTRIBUTE)).isSameAs(previousMapping); + assertThat(request.getAttribute(HandlerMappingIntrospector.CORS_CONFIG_ATTRIBUTE)).isSameAs(previousCorsConfig); + } + + private void testCacheFilter(MockHttpServletRequest request) throws IOException, ServletException { + TestMatchableHandlerMapping mapping = new TestMatchableHandlerMapping(); + StaticWebApplicationContext context = new StaticWebApplicationContext(); + context.registerBean(TestMatchableHandlerMapping.class, () -> mapping); + context.refresh(); + + HandlerMappingIntrospector introspector = initIntrospector(context); + MockHttpServletResponse response = new MockHttpServletResponse(); + + Filter filter = (req, res, chain) -> { + try { + for (int i = 0; i < 10; i++) { + introspector.getMatchableHandlerMapping((HttpServletRequest) req); + introspector.getCorsConfiguration((HttpServletRequest) req); + } + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + chain.doFilter(req, res); + }; + + HttpServlet servlet = new HttpServlet() { + + @Override + protected void service(HttpServletRequest req, HttpServletResponse res) { + try { + res.getWriter().print("Success"); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + } + }; + + new MockFilterChain(servlet, introspector.createCacheFilter(), filter) + .doFilter(request, response); + + assertThat(response.getContentAsString()).isEqualTo("Success"); + assertThat(mapping.getInvocationCount()).isEqualTo(1); + } + private HandlerMappingIntrospector initIntrospector(WebApplicationContext context) { HandlerMappingIntrospector introspector = new HandlerMappingIntrospector(); introspector.setApplicationContext(context); @@ -201,15 +271,6 @@ public class HandlerMappingIntrospectorTests { } - private static class TestHandlerMapping implements HandlerMapping { - - @Override - public HandlerExecutionChain getHandler(HttpServletRequest request) { - return new HandlerExecutionChain(new Object()); - } - } - - @Configuration static class TestConfig { @@ -248,6 +309,7 @@ public class HandlerMappingIntrospectorTests { } } + private static class TestPathPatternParser extends PathPatternParser { private final List parsedPatterns = new ArrayList<>(); @@ -264,4 +326,25 @@ public class HandlerMappingIntrospectorTests { } } + + private static class TestMatchableHandlerMapping implements MatchableHandlerMapping { + + private int invocationCount; + + public int getInvocationCount() { + return this.invocationCount; + } + + @Override + public HandlerExecutionChain getHandler(HttpServletRequest request) { + this.invocationCount++; + return new HandlerExecutionChain(new Object()); + } + + @Override + public RequestMatchResult match(HttpServletRequest request, String pattern) { + throw new UnsupportedOperationException(); + } + } + }