@ -17,19 +17,14 @@
@@ -17,19 +17,14 @@
package org.springframework.web.servlet.function ;
import java.io.IOException ;
import java.time.Duration ;
import java.util.concurrent.CompletableFuture ;
import java.util.concurrent.atomic.AtomicInteger ;
import java.util.concurrent.CompletionException ;
import java.util.function.Function ;
import javax.servlet.AsyncContext ;
import javax.servlet.AsyncListener ;
import javax.servlet.ServletContext ;
import javax.servlet.ServletException ;
import javax.servlet.ServletRequest ;
import javax.servlet.ServletResponse ;
import javax.servlet.http.Cookie ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletRequestWrapper ;
import javax.servlet.http.HttpServletResponse ;
import org.reactivestreams.Publisher ;
@ -42,6 +37,10 @@ import org.springframework.lang.Nullable;
@@ -42,6 +37,10 @@ import org.springframework.lang.Nullable;
import org.springframework.util.Assert ;
import org.springframework.util.ClassUtils ;
import org.springframework.util.MultiValueMap ;
import org.springframework.web.context.request.async.AsyncWebRequest ;
import org.springframework.web.context.request.async.DeferredResult ;
import org.springframework.web.context.request.async.WebAsyncManager ;
import org.springframework.web.context.request.async.WebAsyncUtils ;
import org.springframework.web.servlet.ModelAndView ;
/ * *
@ -59,9 +58,13 @@ final class AsyncServerResponse extends ErrorHandlingServerResponse {
@@ -59,9 +58,13 @@ final class AsyncServerResponse extends ErrorHandlingServerResponse {
private final CompletableFuture < ServerResponse > futureResponse ;
@Nullable
private final Duration timeout ;
private AsyncServerResponse ( CompletableFuture < ServerResponse > futureResponse ) {
private AsyncServerResponse ( CompletableFuture < ServerResponse > futureResponse , @Nullable Duration timeout ) {
this . futureResponse = futureResponse ;
this . timeout = timeout ;
}
@Override
@ -96,44 +99,62 @@ final class AsyncServerResponse extends ErrorHandlingServerResponse {
@@ -96,44 +99,62 @@ final class AsyncServerResponse extends ErrorHandlingServerResponse {
@Nullable
@Override
public ModelAndView writeTo ( HttpServletRequest request , HttpServletResponse response , Context context ) {
public ModelAndView writeTo ( HttpServletRequest request , HttpServletResponse response , Context context )
throws ServletException , IOException {
SharedAsyncContextHttpServletRequest sharedRequest = new SharedAsyncContextHttpServletRequest ( request ) ;
AsyncContext asyncContext = sharedRequest . startAsync ( request , response ) ;
this . futureResponse . whenComplete ( ( futureResponse , futureThrowable ) - > {
try {
if ( futureResponse ! = null ) {
ModelAndView mav = futureResponse . writeTo ( sharedRequest , response , context ) ;
Assert . state ( mav = = null , "Asynchronous, rendering ServerResponse implementations are not " +
"supported in WebMvc.fn. Please use WebFlux.fn instead." ) ;
}
else if ( futureThrowable ! = null ) {
handleError ( futureThrowable , request , response , context ) ;
}
}
catch ( Throwable throwable ) {
try {
handleError ( throwable , request , response , context ) ;
}
catch ( ServletException | IOException ex ) {
logger . warn ( "Asynchronous execution resulted in exception" , ex ) ;
writeAsync ( request , response , createDeferredResult ( ) ) ;
return null ;
}
static void writeAsync ( HttpServletRequest request , HttpServletResponse response , DeferredResult < ? > deferredResult )
throws ServletException , IOException {
WebAsyncManager asyncManager = WebAsyncUtils . getAsyncManager ( request ) ;
AsyncWebRequest asyncWebRequest = WebAsyncUtils . createAsyncWebRequest ( request , response ) ;
asyncManager . setAsyncWebRequest ( asyncWebRequest ) ;
try {
asyncManager . startDeferredResultProcessing ( deferredResult ) ;
}
catch ( IOException | ServletException ex ) {
throw ex ;
}
catch ( Exception ex ) {
throw new ServletException ( "Async processing failed" , ex ) ;
}
}
private DeferredResult < ServerResponse > createDeferredResult ( ) {
DeferredResult < ServerResponse > result ;
if ( this . timeout ! = null ) {
result = new DeferredResult < > ( this . timeout . toMillis ( ) ) ;
}
else {
result = new DeferredResult < > ( ) ;
}
this . futureResponse . handle ( ( value , ex ) - > {
if ( ex ! = null ) {
if ( ex instanceof CompletionException & & ex . getCause ( ) ! = null ) {
ex = ex . getCause ( ) ;
}
result . setErrorResult ( ex ) ;
}
finally {
asyncContext . complete ( ) ;
else {
result . setResult ( value ) ;
}
return null ;
} ) ;
return null ;
return result ;
}
@SuppressWarnings ( { "unchecked" } )
public static ServerResponse create ( Object o ) {
public static ServerResponse create ( Object o , @Nullable Duration timeout ) {
Assert . notNull ( o , "Argument to async must not be null" ) ;
if ( o instanceof CompletableFuture ) {
CompletableFuture < ServerResponse > futureResponse = ( CompletableFuture < ServerResponse > ) o ;
return new AsyncServerResponse ( futureResponse ) ;
return new AsyncServerResponse ( futureResponse , timeout ) ;
}
else if ( reactiveStreamsPresent ) {
ReactiveAdapterRegistry registry = ReactiveAdapterRegistry . getSharedInstance ( ) ;
@ -144,7 +165,7 @@ final class AsyncServerResponse extends ErrorHandlingServerResponse {
@@ -144,7 +165,7 @@ final class AsyncServerResponse extends ErrorHandlingServerResponse {
if ( futureAdapter ! = null ) {
CompletableFuture < ServerResponse > futureResponse =
( CompletableFuture < ServerResponse > ) futureAdapter . fromPublisher ( publisher ) ;
return new AsyncServerResponse ( futureResponse ) ;
return new AsyncServerResponse ( futureResponse , timeout ) ;
}
}
}
@ -152,150 +173,4 @@ final class AsyncServerResponse extends ErrorHandlingServerResponse {
@@ -152,150 +173,4 @@ final class AsyncServerResponse extends ErrorHandlingServerResponse {
}
/ * *
* HttpServletRequestWrapper that shares its AsyncContext between this
* AsyncServerResponse class and other , subsequent ServerResponse
* implementations , keeping track of how many contexts where
* started with startAsync ( ) . This way , we make sure that
* { @link AsyncContext # complete ( ) } only completes for the response that
* finishes last , and is not closed prematurely .
* /
private static final class SharedAsyncContextHttpServletRequest extends HttpServletRequestWrapper {
private final AsyncContext asyncContext ;
private final AtomicInteger startedContexts ;
public SharedAsyncContextHttpServletRequest ( HttpServletRequest request ) {
super ( request ) ;
this . asyncContext = request . startAsync ( ) ;
this . startedContexts = new AtomicInteger ( 0 ) ;
}
private SharedAsyncContextHttpServletRequest ( HttpServletRequest request , AsyncContext asyncContext ,
AtomicInteger startedContexts ) {
super ( request ) ;
this . asyncContext = asyncContext ;
this . startedContexts = startedContexts ;
}
@Override
public AsyncContext startAsync ( ) throws IllegalStateException {
this . startedContexts . incrementAndGet ( ) ;
return new SharedAsyncContext ( this . asyncContext , this , this . asyncContext . getResponse ( ) ,
this . startedContexts ) ;
}
@Override
public AsyncContext startAsync ( ServletRequest servletRequest , ServletResponse servletResponse )
throws IllegalStateException {
this . startedContexts . incrementAndGet ( ) ;
SharedAsyncContextHttpServletRequest sharedRequest ;
if ( servletRequest instanceof SharedAsyncContextHttpServletRequest ) {
sharedRequest = ( SharedAsyncContextHttpServletRequest ) servletRequest ;
}
else {
sharedRequest = new SharedAsyncContextHttpServletRequest ( ( HttpServletRequest ) servletRequest ,
this . asyncContext , this . startedContexts ) ;
}
return new SharedAsyncContext ( this . asyncContext , sharedRequest , servletResponse , this . startedContexts ) ;
}
@Override
public AsyncContext getAsyncContext ( ) {
return new SharedAsyncContext ( this . asyncContext , this , this . asyncContext . getResponse ( ) , this . startedContexts ) ;
}
private static final class SharedAsyncContext implements AsyncContext {
private final AsyncContext delegate ;
private final AtomicInteger openContexts ;
private final ServletRequest request ;
private final ServletResponse response ;
public SharedAsyncContext ( AsyncContext delegate , SharedAsyncContextHttpServletRequest request ,
ServletResponse response , AtomicInteger usageCount ) {
this . delegate = delegate ;
this . request = request ;
this . response = response ;
this . openContexts = usageCount ;
}
@Override
public void complete ( ) {
if ( this . openContexts . decrementAndGet ( ) = = 0 ) {
this . delegate . complete ( ) ;
}
}
@Override
public ServletRequest getRequest ( ) {
return this . request ;
}
@Override
public ServletResponse getResponse ( ) {
return this . response ;
}
@Override
public boolean hasOriginalRequestAndResponse ( ) {
return this . delegate . hasOriginalRequestAndResponse ( ) ;
}
@Override
public void dispatch ( ) {
this . delegate . dispatch ( ) ;
}
@Override
public void dispatch ( String path ) {
this . delegate . dispatch ( path ) ;
}
@Override
public void dispatch ( ServletContext context , String path ) {
this . delegate . dispatch ( context , path ) ;
}
@Override
public void start ( Runnable run ) {
this . delegate . start ( run ) ;
}
@Override
public void addListener ( AsyncListener listener ) {
this . delegate . addListener ( listener ) ;
}
@Override
public void addListener ( AsyncListener listener ,
ServletRequest servletRequest ,
ServletResponse servletResponse ) {
this . delegate . addListener ( listener , servletRequest , servletResponse ) ;
}
@Override
public < T extends AsyncListener > T createListener ( Class < T > clazz ) throws ServletException {
return this . delegate . createListener ( clazz ) ;
}
@Override
public void setTimeout ( long timeout ) {
this . delegate . setTimeout ( timeout ) ;
}
@Override
public long getTimeout ( ) {
return this . delegate . getTimeout ( ) ;
}
}
}
}