From 7a896d7d80a9bdbbbbe5f987c3f79c7578cbb66c Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Fri, 6 Apr 2018 11:01:41 -0400 Subject: [PATCH] TestDispatcherServlet unwraps to find mock request Issue: SPR-16695 --- .../web/servlet/TestDispatcherServlet.java | 10 +- .../samples/standalone/FilterTests.java | 98 ++++++++++++++++++- 2 files changed, 103 insertions(+), 5 deletions(-) diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java index a85592b2b7..4b6e4e6bcc 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java @@ -26,6 +26,8 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.lang.Nullable; import org.springframework.mock.web.MockAsyncContext; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.util.Assert; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.async.CallableProcessingInterceptor; @@ -35,6 +37,7 @@ import org.springframework.web.context.request.async.WebAsyncUtils; import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.ModelAndView; +import org.springframework.web.util.WebUtils; /** * A sub-class of {@code DispatcherServlet} that saves the result in an @@ -68,8 +71,13 @@ final class TestDispatcherServlet extends DispatcherServlet { super.service(request, response); if (request.getAsyncContext() != null) { + MockHttpServletRequest mockRequest = WebUtils.getNativeRequest(request, MockHttpServletRequest.class); + Assert.notNull(mockRequest, "Expected MockHttpServletRequest"); + MockAsyncContext mockAsyncContext = ((MockAsyncContext) mockRequest.getAsyncContext()); + Assert.notNull(mockAsyncContext, "MockAsyncContext not found. Did request wrapper not delegate startAsync?"); + CountDownLatch dispatchLatch = new CountDownLatch(1); - ((MockAsyncContext) request.getAsyncContext()).addDispatchHandler(dispatchLatch::countDown); + mockAsyncContext.addDispatchHandler(dispatchLatch::countDown); getMvcResult(request).setAsyncDispatchLatch(dispatchLatch); } } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java index 0f7bc970c9..417778ead5 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,9 +19,14 @@ package org.springframework.test.web.servlet.samples.standalone; import java.io.IOException; import java.security.Principal; import java.util.concurrent.CompletableFuture; +import javax.servlet.AsyncContext; +import javax.servlet.AsyncListener; import javax.servlet.Filter; import javax.servlet.FilterChain; +import javax.servlet.ServletContext; import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; @@ -120,10 +125,10 @@ public class FilterTests { .andExpect(model().attribute("principal", WrappingRequestResponseFilter.PRINCIPAL_NAME)); } - @Test // SPR-16067 - public void filterWrapsRequestResponseWithAsyncDispatch() throws Exception { + @Test // SPR-16067, SPR-16695 + public void filterWrapsRequestResponseAndPerformsAsyncDispatch() throws Exception { MockMvc mockMvc = standaloneSetup(new PersonController()) - .addFilters(new ShallowEtagHeaderFilter()) + .addFilters(new WrappingRequestResponseFilter(), new ShallowEtagHeaderFilter()) .build(); MvcResult mvcResult = mockMvc.perform(get("/persons/1").accept(MediaType.APPLICATION_JSON)) @@ -189,10 +194,20 @@ public class FilterTests { FilterChain filterChain) throws ServletException, IOException { filterChain.doFilter(new HttpServletRequestWrapper(request) { + @Override public Principal getUserPrincipal() { return () -> PRINCIPAL_NAME; } + + // Like Spring Security does in HttpServlet3RequestFactory.. + + @Override + public AsyncContext getAsyncContext() { + return super.getAsyncContext() != null ? + new AsyncContextWrapper(super.getAsyncContext()) : null; + } + }, new HttpServletResponseWrapper(response)); } } @@ -206,4 +221,79 @@ public class FilterTests { response.sendRedirect("/login"); } } + + + private static class AsyncContextWrapper implements AsyncContext { + + private final AsyncContext delegate; + + public AsyncContextWrapper(AsyncContext delegate) { + this.delegate = delegate; + } + + @Override + public ServletRequest getRequest() { + return this.delegate.getRequest(); + } + + @Override + public ServletResponse getResponse() { + return this.delegate.getResponse(); + } + + @Override + public boolean hasOriginalRequestAndResponse() { + return this.delegate.hasOriginalRequestAndResponse(); + } + + @Override + public void dispatch() { + this.delegate.dispatch(); + } + + @Override + public void dispatch(String path) { + this.delegate.dispatch(path); + } + + @Override + public void dispatch(ServletContext context, String path) { + this.delegate.dispatch(context, path); + } + + @Override + public void complete() { + this.delegate.complete(); + } + + @Override + public void start(Runnable run) { + this.delegate.start(run); + } + + @Override + public void addListener(AsyncListener listener) { + this.delegate.addListener(listener); + } + + @Override + public void addListener(AsyncListener listener, ServletRequest req, ServletResponse res) { + this.delegate.addListener(listener, req, res); + } + + @Override + public T createListener(Class clazz) throws ServletException { + return this.delegate.createListener(clazz); + } + + @Override + public void setTimeout(long timeout) { + this.delegate.setTimeout(timeout); + } + + @Override + public long getTimeout() { + return this.delegate.getTimeout(); + } + } }