diff --git a/spring-web/src/main/java/org/springframework/web/filter/ServletRequestPathFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ServletRequestPathFilter.java index d1cda60b04..cc9e85015f 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ServletRequestPathFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ServletRequestPathFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; +import org.springframework.http.server.RequestPath; import org.springframework.web.util.ServletRequestPathUtils; /** @@ -48,12 +49,13 @@ public class ServletRequestPathFilter implements Filter { public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { + RequestPath previousRequestPath = (RequestPath) request.getAttribute(ServletRequestPathUtils.PATH_ATTRIBUTE); ServletRequestPathUtils.parseAndCache((HttpServletRequest) request); try { chain.doFilter(request, response); } finally { - ServletRequestPathUtils.clearParsedRequestPath(request); + ServletRequestPathUtils.setParsedRequestPath(previousRequestPath, request); } } diff --git a/spring-web/src/main/java/org/springframework/web/util/ServletRequestPathUtils.java b/spring-web/src/main/java/org/springframework/web/util/ServletRequestPathUtils.java index 200478ea55..8d2cf41675 100644 --- a/spring-web/src/main/java/org/springframework/web/util/ServletRequestPathUtils.java +++ b/spring-web/src/main/java/org/springframework/web/util/ServletRequestPathUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import javax.servlet.http.HttpServletRequest; import org.springframework.http.server.PathContainer; import org.springframework.http.server.RequestPath; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -71,6 +72,22 @@ public abstract class ServletRequestPathUtils { return path; } + /** + * Set the cached, parsed {@code RequestPath} to the given value. + * @param requestPath the value to set to, or if {@code null} the cache + * value is cleared. + * @param request the current request + * @since 5.3.3 + */ + public static void setParsedRequestPath(@Nullable RequestPath requestPath, ServletRequest request) { + if (requestPath != null) { + request.setAttribute(PATH_ATTRIBUTE, requestPath); + } + else { + request.removeAttribute(PATH_ATTRIBUTE); + } + } + /** * Check for a {@link #parseAndCache previously} parsed and cached {@code RequestPath}. */ diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java index bc7fc189e6..6b9016e503 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -952,9 +952,10 @@ public class DispatcherServlet extends FrameworkServlet { request.setAttribute(FLASH_MAP_MANAGER_ATTRIBUTE, this.flashMapManager); } - RequestPath requestPath = null; - if (this.parseRequestPath && !ServletRequestPathUtils.hasParsedRequestPath(request)) { - requestPath = ServletRequestPathUtils.parseAndCache(request); + RequestPath previousRequestPath = null; + if (this.parseRequestPath) { + previousRequestPath = (RequestPath) request.getAttribute(ServletRequestPathUtils.PATH_ATTRIBUTE); + ServletRequestPathUtils.parseAndCache(request); } try { @@ -967,9 +968,7 @@ public class DispatcherServlet extends FrameworkServlet { restoreAttributesAfterInclude(request, attributesSnapshot); } } - if (requestPath != null) { - ServletRequestPathUtils.clearParsedRequestPath(request); - } + ServletRequestPathUtils.setParsedRequestPath(previousRequestPath, request); } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/DispatcherServletTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/DispatcherServletTests.java index 0213b52847..adbae39e5f 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/DispatcherServletTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/DispatcherServletTests.java @@ -17,8 +17,11 @@ package org.springframework.web.servlet; import java.io.IOException; +import java.util.Collections; import java.util.Locale; +import java.util.Map; +import javax.servlet.DispatcherType; import javax.servlet.Servlet; import javax.servlet.ServletConfig; import javax.servlet.ServletContext; @@ -35,14 +38,18 @@ import org.springframework.beans.PropertyValue; import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.context.annotation.Bean; import org.springframework.core.env.ConfigurableEnvironment; import org.springframework.http.HttpHeaders; +import org.springframework.http.server.RequestPath; +import org.springframework.web.HttpRequestHandler; import org.springframework.web.context.ConfigurableWebApplicationContext; import org.springframework.web.context.ConfigurableWebEnvironment; import org.springframework.web.context.ContextLoader; import org.springframework.web.context.ServletConfigAwareBean; import org.springframework.web.context.ServletContextAwareBean; import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import org.springframework.web.context.support.StandardServletEnvironment; import org.springframework.web.context.support.StaticWebApplicationContext; import org.springframework.web.multipart.MaxUploadSizeExceededException; @@ -56,7 +63,9 @@ import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import org.springframework.web.testfixture.servlet.MockHttpServletResponse; import org.springframework.web.testfixture.servlet.MockServletConfig; import org.springframework.web.testfixture.servlet.MockServletContext; +import org.springframework.web.util.ServletRequestPathUtils; import org.springframework.web.util.WebUtils; +import org.springframework.web.util.pattern.PathPatternParser; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -700,6 +709,27 @@ public class DispatcherServletTests { complexDispatcherServlet.service(request, response)); } + @Test // gh-26318 + public void parsedRequestPathIsRestoredOnForward() throws Exception { + AnnotationConfigWebApplicationContext context = new AnnotationConfigWebApplicationContext(); + context.register(PathPatternParserConfig.class); + DispatcherServlet servlet = new DispatcherServlet(context); + servlet.init(servletConfig); + + RequestPath previousRequestPath = RequestPath.parse("/", null); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/test"); + request.setDispatcherType(DispatcherType.FORWARD); + request.setAttribute(ServletRequestPathUtils.PATH_ATTRIBUTE, previousRequestPath); + + MockHttpServletResponse response = new MockHttpServletResponse(); + servlet.service(request, response); + + assertThat(response.getStatus()).isEqualTo(200); + assertThat(response.getContentAsString()).isEqualTo("test-body"); + assertThat(request.getAttribute(ServletRequestPathUtils.PATH_ATTRIBUTE)).isSameAs(previousRequestPath); + } + @Test public void dispatcherServletRefresh() throws ServletException { MockServletContext servletContext = new MockServletContext("org/springframework/web/context"); @@ -867,4 +897,22 @@ public class DispatcherServletTests { } } + + private static class PathPatternParserConfig { + + @Bean + public SimpleUrlHandlerMapping handlerMapping() { + Map urlMap = Collections.singletonMap("/test", + (HttpRequestHandler) (request, response) -> { + response.setStatus(200); + response.getWriter().print("test-body"); + }); + + SimpleUrlHandlerMapping mapping = new SimpleUrlHandlerMapping(); + mapping.setPatternParser(new PathPatternParser()); + mapping.setUrlMap(urlMap); + return mapping; + } + } + }