diff --git a/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java b/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java index fa76bd8179..6d89492487 100644 --- a/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java +++ b/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java @@ -15,8 +15,15 @@ */ package org.springframework.test.web.servlet; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import javax.servlet.http.HttpServletRequest; + import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.WebAsyncUtils; import org.springframework.web.servlet.FlashMap; import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.ModelAndView; @@ -43,6 +50,8 @@ class DefaultMvcResult implements MvcResult { private Exception resolvedException; + private CountDownLatch asyncResultLatch; + /** * Create a new instance with the given request and response. @@ -96,4 +105,36 @@ class DefaultMvcResult implements MvcResult { return RequestContextUtils.getOutputFlashMap(mockRequest); } + public void setAsyncResultLatch(CountDownLatch asyncResultLatch) { + this.asyncResultLatch = asyncResultLatch; + } + + public Object getAsyncResult() { + WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(this.mockRequest); + if (asyncManager.isConcurrentHandlingStarted()) { + if (!awaitAsyncResult()) { + throw new IllegalStateException( + "Gave up waiting on async result from [" + this.handler + "] to complete"); + } + if (asyncManager.hasConcurrentResult()) { + return asyncManager.getConcurrentResult(); + } + } + + return null; + } + + private boolean awaitAsyncResult() { + if (this.asyncResultLatch == null) { + return true; + } + long timeout = ((HttpServletRequest) this.mockRequest).getAsyncContext().getTimeout(); + try { + return this.asyncResultLatch.await(timeout, TimeUnit.MILLISECONDS); + } + catch (InterruptedException e) { + return false; + } + } + } diff --git a/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/MvcResult.java b/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/MvcResult.java index 88e11dda3e..c24ba48365 100644 --- a/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/MvcResult.java +++ b/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/MvcResult.java @@ -75,4 +75,14 @@ public interface MvcResult { */ FlashMap getFlashMap(); + /** + * Get the result of asynchronous execution or {@code null} if concurrent + * handling did not start. This method will hold and await the completion + * of concurrent handling. + * + * @throws IllegalStateException if concurrent handling does not complete + * within the allocated async timeout value. + */ + Object getAsyncResult(); + } diff --git a/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java b/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java index 27f14ce150..eccdead211 100644 --- a/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java +++ b/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java @@ -19,7 +19,6 @@ package org.springframework.test.web.servlet; import java.io.IOException; import java.util.concurrent.Callable; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import javax.servlet.ServletException; import javax.servlet.ServletRequest; @@ -56,32 +55,41 @@ final class TestDispatcherServlet extends DispatcherServlet { super(webApplicationContext); } - protected DefaultMvcResult getMvcResult(ServletRequest request) { - return (DefaultMvcResult) request.getAttribute(MockMvc.MVC_RESULT_ATTRIBUTE); - } - @Override protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { - WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request); + CountDownLatch latch = registerAsyncInterceptors(request); + getMvcResult(request).setAsyncResultLatch(latch); - TestCallableInterceptor callableInterceptor = new TestCallableInterceptor(); - asyncManager.registerCallableInterceptor("mock-mvc", callableInterceptor); + super.service(request, response); + } - TestDeferredResultInterceptor deferredResultInterceptor = new TestDeferredResultInterceptor(); - asyncManager.registerDeferredResultInterceptor("mock-mvc", deferredResultInterceptor); + private CountDownLatch registerAsyncInterceptors(HttpServletRequest request) { - super.service(request, response); + final CountDownLatch asyncResultLatch = new CountDownLatch(1); + + WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request); - // TODO: add CountDownLatch to DeferredResultInterceptor and wait in request().asyncResult(..) + asyncManager.registerCallableInterceptor("mockmvc", new CallableProcessingInterceptor() { + public void preProcess(NativeWebRequest request, Callable task) throws Exception { } + public void postProcess(NativeWebRequest request, Callable task, Object value) throws Exception { + asyncResultLatch.countDown(); + } + }); - Object handler = getMvcResult(request).getHandler(); - if (asyncManager.isConcurrentHandlingStarted() && !deferredResultInterceptor.wasInvoked) { - if (!callableInterceptor.await()) { - throw new ServletException( - "Gave up waiting on Callable from [" + handler.getClass().getName() + "] to complete"); + asyncManager.registerDeferredResultInterceptor("mockmvc", new DeferredResultProcessingInterceptor() { + public void preProcess(NativeWebRequest request, DeferredResult result) throws Exception { } + public void postProcess(NativeWebRequest request, DeferredResult result, Object value) throws Exception { + asyncResultLatch.countDown(); } - } + public void afterExpiration(NativeWebRequest request, DeferredResult result) throws Exception { } + }); + + return asyncResultLatch; + } + + protected DefaultMvcResult getMvcResult(ServletRequest request) { + return (DefaultMvcResult) request.getAttribute(MockMvc.MVC_RESULT_ATTRIBUTE); } @Override @@ -118,38 +126,4 @@ final class TestDispatcherServlet extends DispatcherServlet { return mav; } - - private final class TestCallableInterceptor implements CallableProcessingInterceptor { - - private final CountDownLatch latch = new CountDownLatch(1); - - private boolean await() { - try { - return this.latch.await(5, TimeUnit.SECONDS); - } - catch (InterruptedException e) { - return false; - } - } - - public void preProcess(NativeWebRequest request, Callable task) { } - - public void postProcess(NativeWebRequest request, Callable task, Object concurrentResult) { - this.latch.countDown(); - } - } - - private final class TestDeferredResultInterceptor implements DeferredResultProcessingInterceptor { - - private boolean wasInvoked; - - public void preProcess(NativeWebRequest request, DeferredResult deferredResult) { - this.wasInvoked = true; - } - - public void postProcess(NativeWebRequest request, DeferredResult deferredResult, Object concurrentResult) { } - - public void afterExpiration(NativeWebRequest request, DeferredResult deferredResult) { } - } - } diff --git a/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/request/MockAsyncContext.java b/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/request/MockAsyncContext.java index c13ad69a82..c5bc08dc4b 100644 --- a/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/request/MockAsyncContext.java +++ b/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/request/MockAsyncContext.java @@ -48,7 +48,7 @@ class MockAsyncContext implements AsyncContext { private String dispatchedPath; - private long timeout = 10 * 60 * 1000L; // 10 seconds is Tomcat's default + private long timeout = 10 * 1000L; // 10 seconds is Tomcat's default public MockAsyncContext(ServletRequest request, ServletResponse response) { diff --git a/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/result/RequestResultMatchers.java b/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/result/RequestResultMatchers.java index 468060f80d..2f7c84390b 100644 --- a/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/result/RequestResultMatchers.java +++ b/spring-test-mvc/src/main/java/org/springframework/test/web/servlet/result/RequestResultMatchers.java @@ -30,8 +30,6 @@ import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.ResultMatcher; import org.springframework.web.context.request.async.AsyncTask; import org.springframework.web.context.request.async.DeferredResult; -import org.springframework.web.context.request.async.WebAsyncManager; -import org.springframework.web.context.request.async.WebAsyncUtils; /** * Factory for assertions on the request. An instance of this class is @@ -90,9 +88,8 @@ public class RequestResultMatchers { @SuppressWarnings("unchecked") public void match(MvcResult result) { HttpServletRequest request = result.getRequest(); - WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request); MatcherAssert.assertThat("Async started", request.isAsyncStarted(), equalTo(true)); - MatcherAssert.assertThat("Async result", (T) asyncManager.getConcurrentResult(), matcher); + MatcherAssert.assertThat("Async result", (T) result.getAsyncResult(), matcher); } }; } diff --git a/spring-test-mvc/src/test/java/org/springframework/test/web/servlet/StubMvcResult.java b/spring-test-mvc/src/test/java/org/springframework/test/web/servlet/StubMvcResult.java index 7ffaabae2e..6fb0e444d4 100644 --- a/spring-test-mvc/src/test/java/org/springframework/test/web/servlet/StubMvcResult.java +++ b/spring-test-mvc/src/test/java/org/springframework/test/web/servlet/StubMvcResult.java @@ -120,4 +120,8 @@ public class StubMvcResult implements MvcResult { this.response = response; } + public Object getAsyncResult() { + return null; + } + }