@ -17,14 +17,17 @@
@@ -17,14 +17,17 @@
package org.springframework.test.context.jdbc ;
import java.lang.reflect.Method ;
import java.util.Arrays ;
import java.util.List ;
import java.util.Set ;
import java.util.stream.Stream ;
import javax.sql.DataSource ;
import org.apache.commons.logging.Log ;
import org.apache.commons.logging.LogFactory ;
import org.springframework.aot.hint.RuntimeHints ;
import org.springframework.context.ApplicationContext ;
import org.springframework.core.annotation.AnnotatedElementUtils ;
import org.springframework.core.io.ByteArrayResource ;
@ -35,6 +38,7 @@ import org.springframework.lang.NonNull;
@@ -35,6 +38,7 @@ import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable ;
import org.springframework.test.context.TestContext ;
import org.springframework.test.context.TestContextAnnotationUtils ;
import org.springframework.test.context.aot.AotTestExecutionListener ;
import org.springframework.test.context.jdbc.Sql.ExecutionPhase ;
import org.springframework.test.context.jdbc.SqlConfig.ErrorMode ;
import org.springframework.test.context.jdbc.SqlConfig.TransactionMode ;
@ -52,9 +56,11 @@ import org.springframework.util.Assert;
@@ -52,9 +56,11 @@ import org.springframework.util.Assert;
import org.springframework.util.ClassUtils ;
import org.springframework.util.ObjectUtils ;
import org.springframework.util.ReflectionUtils ;
import org.springframework.util.ResourceUtils ;
import org.springframework.util.ReflectionUtils.MethodFilter ;
import org.springframework.util.StringUtils ;
import static org.springframework.util.ResourceUtils.CLASSPATH_URL_PREFIX ;
/ * *
* { @code TestExecutionListener } that provides support for executing SQL
* { @link Sql # scripts scripts } and inlined { @link Sql # statements statements }
@ -90,18 +96,22 @@ import org.springframework.util.StringUtils;
@@ -90,18 +96,22 @@ import org.springframework.util.StringUtils;
* @since 4 . 1
* @see Sql
* @see SqlConfig
* @see SqlMergeMode
* @see SqlGroup
* @see org . springframework . test . context . transaction . TestContextTransactionUtils
* @see org . springframework . test . context . transaction . TransactionalTestExecutionListener
* @see org . springframework . jdbc . datasource . init . ResourceDatabasePopulator
* @see org . springframework . jdbc . datasource . init . ScriptUtils
* /
public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener {
public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener implements AotTestExecutionListener {
private static final String SLASH = "/" ;
private static final Log logger = LogFactory . getLog ( SqlScriptsTestExecutionListener . class ) ;
private static final MethodFilter sqlMethodFilter = ReflectionUtils . USER_DECLARED_METHODS
. and ( method - > AnnotatedElementUtils . hasAnnotation ( method , Sql . class ) ) ;
/ * *
* Returns { @code 5000 } .
@ -129,6 +139,21 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
@@ -129,6 +139,21 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
executeSqlScripts ( testContext , ExecutionPhase . AFTER_TEST_METHOD ) ;
}
/ * *
* Process the supplied test class and its methods and register run - time
* hints for any SQL scripts configured or detected as classpath resources
* via { @link Sql @Sql } .
* @since 6 . 0
* /
@Override
public void processAheadOfTime ( Class < ? > testClass , RuntimeHints runtimeHints , ClassLoader classLoader ) {
getSqlAnnotationsFor ( testClass ) . forEach ( sql - >
registerClasspathResources ( runtimeHints , getScripts ( sql , testClass , null , true ) ) ) ;
getSqlMethods ( testClass ) . forEach ( testMethod - >
getSqlAnnotationsFor ( testMethod ) . forEach ( sql - >
registerClasspathResources ( runtimeHints , getScripts ( sql , testClass , testMethod , false ) ) ) ) ;
}
/ * *
* Execute SQL scripts configured via { @link Sql @Sql } for the supplied
* { @link TestContext } and { @link ExecutionPhase } .
@ -226,8 +251,7 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
@@ -226,8 +251,7 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
mergedSqlConfig , executionPhase , testContext ) ) ;
}
String [ ] scripts = getScripts ( sql , testContext , classLevel ) ;
scripts = TestContextResourceUtils . convertToClasspathResourcePaths ( testContext . getTestClass ( ) , scripts ) ;
String [ ] scripts = getScripts ( sql , testContext . getTestClass ( ) , testContext . getTestMethod ( ) , classLevel ) ;
List < Resource > scriptResources = TestContextResourceUtils . convertToResourceList (
testContext . getApplicationContext ( ) , scripts ) ;
for ( String stmt : sql . statements ( ) ) {
@ -321,31 +345,29 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
@@ -321,31 +345,29 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
return null ;
}
private String [ ] getScripts ( Sql sql , TestContext testContext , boolean classLevel ) {
private String [ ] getScripts ( Sql sql , Class < ? > testClass , Method testMethod , boolean classLevel ) {
String [ ] scripts = sql . scripts ( ) ;
if ( ObjectUtils . isEmpty ( scripts ) & & ObjectUtils . isEmpty ( sql . statements ( ) ) ) {
scripts = new String [ ] { detectDefaultScript ( testContext , classLevel ) } ;
scripts = new String [ ] { detectDefaultScript ( testClass , testMethod , classLevel ) } ;
}
return scripts ;
return TestContextResourceUtils . convertToClasspathResourcePaths ( testClass , scripts ) ;
}
/ * *
* Detect a default SQL script by implementing the algorithm defined in
* { @link Sql # scripts } .
* /
private String detectDefaultScript ( TestContext testContext , boolean classLevel ) {
Class < ? > clazz = testContext . getTestClass ( ) ;
Method method = testContext . getTestMethod ( ) ;
private String detectDefaultScript ( Class < ? > testClass , Method testMethod , boolean classLevel ) {
String elementType = ( classLevel ? "class" : "method" ) ;
String elementName = ( classLevel ? clazz . getName ( ) : m ethod. toString ( ) ) ;
String elementName = ( classLevel ? testClass . getName ( ) : testMethod . toString ( ) ) ;
String resourcePath = ClassUtils . convertClassNameToResourcePath ( clazz . getName ( ) ) ;
String resourcePath = ClassUtils . convertClassNameToResourcePath ( testClass . getName ( ) ) ;
if ( ! classLevel ) {
resourcePath + = "." + m ethod. getName ( ) ;
resourcePath + = "." + testM ethod. getName ( ) ;
}
resourcePath + = ".sql" ;
String prefixedResourcePath = ResourceUtils . CLASSPATH_URL_PREFIX + SLASH + resourcePath ;
String prefixedResourcePath = CLASSPATH_URL_PREFIX + SLASH + resourcePath ;
ClassPathResource classPathResource = new ClassPathResource ( resourcePath ) ;
if ( classPathResource . exists ( ) ) {
@ -364,4 +386,23 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
@@ -364,4 +386,23 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
}
}
private Stream < Method > getSqlMethods ( Class < ? > testClass ) {
return Arrays . stream ( ReflectionUtils . getUniqueDeclaredMethods ( testClass , sqlMethodFilter ) ) ;
}
private void registerClasspathResources ( RuntimeHints runtimeHints , String . . . locations ) {
Arrays . stream ( locations )
. filter ( location - > location . startsWith ( CLASSPATH_URL_PREFIX ) )
. map ( this : : cleanClasspathResource )
. forEach ( runtimeHints . resources ( ) : : registerPattern ) ;
}
private String cleanClasspathResource ( String location ) {
location = location . substring ( CLASSPATH_URL_PREFIX . length ( ) ) ;
if ( ! location . startsWith ( SLASH ) ) {
location = SLASH + location ;
}
return location ;
}
}