@ -17,6 +17,8 @@
@@ -17,6 +17,8 @@
package org.springframework.orm.jpa.support ;
import java.util.concurrent.Callable ;
import java.util.concurrent.CountDownLatch ;
import java.util.concurrent.TimeUnit ;
import java.util.concurrent.atomic.AtomicInteger ;
import javax.persistence.EntityManager ;
@ -71,6 +73,8 @@ public class OpenEntityManagerInViewTests {
@@ -71,6 +73,8 @@ public class OpenEntityManagerInViewTests {
private ServletWebRequest webRequest ;
private final TestTaskExecutor taskExecutor = new TestTaskExecutor ( ) ;
@BeforeEach
public void setUp ( ) {
@ -144,10 +148,13 @@ public class OpenEntityManagerInViewTests {
@@ -144,10 +148,13 @@ public class OpenEntityManagerInViewTests {
AsyncWebRequest asyncWebRequest = new StandardServletAsyncWebRequest ( this . request , this . response ) ;
WebAsyncManager asyncManager = WebAsyncUtils . getAsyncManager ( this . webRequest ) ;
asyncManager . setTaskExecutor ( new SyncTaskExecutor ( ) ) ;
asyncManager . setTaskExecutor ( this . taskExecutor ) ;
asyncManager . setAsyncWebRequest ( asyncWebRequest ) ;
asyncManager . startCallableProcessing ( ( Callable < String > ) ( ) - > "anything" ) ;
this . taskExecutor . await ( ) ;
assertThat ( asyncManager . getConcurrentResult ( ) ) . as ( "Concurrent result " ) . isEqualTo ( "anything" ) ;
interceptor . afterConcurrentHandlingStarted ( this . webRequest ) ;
assertThat ( TransactionSynchronizationManager . hasResource ( factory ) ) . isFalse ( ) ;
@ -198,10 +205,13 @@ public class OpenEntityManagerInViewTests {
@@ -198,10 +205,13 @@ public class OpenEntityManagerInViewTests {
AsyncWebRequest asyncWebRequest = new StandardServletAsyncWebRequest ( this . request , this . response ) ;
WebAsyncManager asyncManager = WebAsyncUtils . getAsyncManager ( this . request ) ;
asyncManager . setTaskExecutor ( new SyncTaskExecutor ( ) ) ;
asyncManager . setTaskExecutor ( this . taskExecutor ) ;
asyncManager . setAsyncWebRequest ( asyncWebRequest ) ;
asyncManager . startCallableProcessing ( ( Callable < String > ) ( ) - > "anything" ) ;
this . taskExecutor . await ( ) ;
assertThat ( asyncManager . getConcurrentResult ( ) ) . as ( "Concurrent result " ) . isEqualTo ( "anything" ) ;
interceptor . afterConcurrentHandlingStarted ( this . webRequest ) ;
assertThat ( TransactionSynchronizationManager . hasResource ( this . factory ) ) . isFalse ( ) ;
@ -235,10 +245,13 @@ public class OpenEntityManagerInViewTests {
@@ -235,10 +245,13 @@ public class OpenEntityManagerInViewTests {
AsyncWebRequest asyncWebRequest = new StandardServletAsyncWebRequest ( this . request , this . response ) ;
WebAsyncManager asyncManager = WebAsyncUtils . getAsyncManager ( this . request ) ;
asyncManager . setTaskExecutor ( new SyncTaskExecutor ( ) ) ;
asyncManager . setTaskExecutor ( this . taskExecutor ) ;
asyncManager . setAsyncWebRequest ( asyncWebRequest ) ;
asyncManager . startCallableProcessing ( ( Callable < String > ) ( ) - > "anything" ) ;
this . taskExecutor . await ( ) ;
assertThat ( asyncManager . getConcurrentResult ( ) ) . as ( "Concurrent result " ) . isEqualTo ( "anything" ) ;
interceptor . afterConcurrentHandlingStarted ( this . webRequest ) ;
assertThat ( TransactionSynchronizationManager . hasResource ( this . factory ) ) . isFalse ( ) ;
@ -360,10 +373,13 @@ public class OpenEntityManagerInViewTests {
@@ -360,10 +373,13 @@ public class OpenEntityManagerInViewTests {
given ( asyncWebRequest . isAsyncStarted ( ) ) . willReturn ( true ) ;
WebAsyncManager asyncManager = WebAsyncUtils . getAsyncManager ( this . request ) ;
asyncManager . setTaskExecutor ( new SyncTaskExecutor ( ) ) ;
asyncManager . setTaskExecutor ( this . taskExecutor ) ;
asyncManager . setAsyncWebRequest ( asyncWebRequest ) ;
asyncManager . startCallableProcessing ( ( Callable < String > ) ( ) - > "anything" ) ;
this . taskExecutor . await ( ) ;
assertThat ( asyncManager . getConcurrentResult ( ) ) . as ( "Concurrent result " ) . isEqualTo ( "anything" ) ;
assertThat ( TransactionSynchronizationManager . hasResource ( factory ) ) . isFalse ( ) ;
assertThat ( TransactionSynchronizationManager . hasResource ( factory2 ) ) . isFalse ( ) ;
filter2 . doFilter ( this . request , this . response , filterChain3 ) ;
@ -398,11 +414,26 @@ public class OpenEntityManagerInViewTests {
@@ -398,11 +414,26 @@ public class OpenEntityManagerInViewTests {
@SuppressWarnings ( "serial" )
private static class SyncTaskExecutor extends SimpleAsyncTaskExecutor {
private static class TestTaskExecutor extends SimpleAsyncTaskExecutor {
private final CountDownLatch latch = new CountDownLatch ( 1 ) ;
@Override
public void execute ( Runnable task , long startTimeout ) {
task . run ( ) ;
Runnable decoratedTask = ( ) - > {
try {
task . run ( ) ;
}
finally {
latch . countDown ( ) ;
}
} ;
super . execute ( decoratedTask , startTimeout ) ;
}
void await ( ) throws InterruptedException {
this . latch . await ( 5 , TimeUnit . SECONDS ) ;
}
}
}