@ -19,7 +19,6 @@ package org.springframework.test.web.servlet;
@@ -19,7 +19,6 @@ package org.springframework.test.web.servlet;
import java.io.IOException ;
import java.util.concurrent.Callable ;
import java.util.concurrent.CountDownLatch ;
import java.util.concurrent.TimeUnit ;
import javax.servlet.ServletException ;
import javax.servlet.ServletRequest ;
@ -56,32 +55,41 @@ final class TestDispatcherServlet extends DispatcherServlet {
@@ -56,32 +55,41 @@ final class TestDispatcherServlet extends DispatcherServlet {
super ( webApplicationContext ) ;
}
protected DefaultMvcResult getMvcResult ( ServletRequest request ) {
return ( DefaultMvcResult ) request . getAttribute ( MockMvc . MVC_RESULT_ATTRIBUTE ) ;
}
@Override
protected void service ( HttpServletRequest request , HttpServletResponse response ) throws ServletException , IOException {
WebAsyncManager asyncManager = WebAsyncUtils . getAsyncManager ( request ) ;
CountDownLatch latch = registerAsyncInterceptors ( request ) ;
getMvcResult ( request ) . setAsyncResultLatch ( latch ) ;
TestCallableInterceptor callableInterceptor = new TestCallableInterceptor ( ) ;
asyncManager . registerCallableInterceptor ( "mock-mvc" , callableInterceptor ) ;
super . service ( request , response ) ;
}
TestDeferredResultInterceptor deferredResultInterceptor = new TestDeferredResultInterceptor ( ) ;
asyncManager . registerDeferredResultInterceptor ( "mock-mvc" , deferredResultInterceptor ) ;
private CountDownLatch registerAsyncInterceptors ( HttpServletRequest request ) {
super . service ( request , response ) ;
final CountDownLatch asyncResultLatch = new CountDownLatch ( 1 ) ;
WebAsyncManager asyncManager = WebAsyncUtils . getAsyncManager ( request ) ;
// TODO: add CountDownLatch to DeferredResultInterceptor and wait in request().asyncResult(..)
asyncManager . registerCallableInterceptor ( "mockmvc" , new CallableProcessingInterceptor ( ) {
public void preProcess ( NativeWebRequest request , Callable < ? > task ) throws Exception { }
public void postProcess ( NativeWebRequest request , Callable < ? > task , Object value ) throws Exception {
asyncResultLatch . countDown ( ) ;
}
} ) ;
Object handler = getMvcResult ( request ) . getHandler ( ) ;
if ( asyncManager . isConcurrentHandlingStarted ( ) & & ! deferredResultInterceptor . wasInvoked ) {
if ( ! callableInterceptor . await ( ) ) {
throw new ServletException (
"Gave up waiting on Callable from [" + handler . getClass ( ) . getName ( ) + "] to complete" ) ;
asyncManager . registerDeferredResultInterceptor ( "mockmvc" , new DeferredResultProcessingInterceptor ( ) {
public void preProcess ( NativeWebRequest request , DeferredResult < ? > result ) throws Exception { }
public void postProcess ( NativeWebRequest request , DeferredResult < ? > result , Object value ) throws Exception {
asyncResultLatch . countDown ( ) ;
}
}
public void afterExpiration ( NativeWebRequest request , DeferredResult < ? > result ) throws Exception { }
} ) ;
return asyncResultLatch ;
}
protected DefaultMvcResult getMvcResult ( ServletRequest request ) {
return ( DefaultMvcResult ) request . getAttribute ( MockMvc . MVC_RESULT_ATTRIBUTE ) ;
}
@Override
@ -118,38 +126,4 @@ final class TestDispatcherServlet extends DispatcherServlet {
@@ -118,38 +126,4 @@ final class TestDispatcherServlet extends DispatcherServlet {
return mav ;
}
private final class TestCallableInterceptor implements CallableProcessingInterceptor {
private final CountDownLatch latch = new CountDownLatch ( 1 ) ;
private boolean await ( ) {
try {
return this . latch . await ( 5 , TimeUnit . SECONDS ) ;
}
catch ( InterruptedException e ) {
return false ;
}
}
public void preProcess ( NativeWebRequest request , Callable < ? > task ) { }
public void postProcess ( NativeWebRequest request , Callable < ? > task , Object concurrentResult ) {
this . latch . countDown ( ) ;
}
}
private final class TestDeferredResultInterceptor implements DeferredResultProcessingInterceptor {
private boolean wasInvoked ;
public void preProcess ( NativeWebRequest request , DeferredResult < ? > deferredResult ) {
this . wasInvoked = true ;
}
public void postProcess ( NativeWebRequest request , DeferredResult < ? > deferredResult , Object concurrentResult ) { }
public void afterExpiration ( NativeWebRequest request , DeferredResult < ? > deferredResult ) { }
}
}