Browse Source

Register runtime hints for @Sql scripts

SqlScriptsTestExecutionListener now implements AotTestExecutionListener
in order to register run-time hints for SQL scripts used by test
classes and test methods annotated with @Sql if the configured or
detected SQL scripts are classpath resources.

Closes gh-29027
pull/29083/head
Sam Brannen 2 years ago
parent
commit
e57b5f1dfc
  1. 69
      spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java
  2. 6
      spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorTests.java

69
spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java

@ -17,14 +17,17 @@ @@ -17,14 +17,17 @@
package org.springframework.test.context.jdbc;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;
import javax.sql.DataSource;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.io.ByteArrayResource;
@ -35,6 +38,7 @@ import org.springframework.lang.NonNull; @@ -35,6 +38,7 @@ import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.TestContextAnnotationUtils;
import org.springframework.test.context.aot.AotTestExecutionListener;
import org.springframework.test.context.jdbc.Sql.ExecutionPhase;
import org.springframework.test.context.jdbc.SqlConfig.ErrorMode;
import org.springframework.test.context.jdbc.SqlConfig.TransactionMode;
@ -52,9 +56,11 @@ import org.springframework.util.Assert; @@ -52,9 +56,11 @@ import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.ResourceUtils;
import org.springframework.util.ReflectionUtils.MethodFilter;
import org.springframework.util.StringUtils;
import static org.springframework.util.ResourceUtils.CLASSPATH_URL_PREFIX;
/**
* {@code TestExecutionListener} that provides support for executing SQL
* {@link Sql#scripts scripts} and inlined {@link Sql#statements statements}
@ -90,18 +96,22 @@ import org.springframework.util.StringUtils; @@ -90,18 +96,22 @@ import org.springframework.util.StringUtils;
* @since 4.1
* @see Sql
* @see SqlConfig
* @see SqlMergeMode
* @see SqlGroup
* @see org.springframework.test.context.transaction.TestContextTransactionUtils
* @see org.springframework.test.context.transaction.TransactionalTestExecutionListener
* @see org.springframework.jdbc.datasource.init.ResourceDatabasePopulator
* @see org.springframework.jdbc.datasource.init.ScriptUtils
*/
public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener {
public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener implements AotTestExecutionListener {
private static final String SLASH = "/";
private static final Log logger = LogFactory.getLog(SqlScriptsTestExecutionListener.class);
private static final MethodFilter sqlMethodFilter = ReflectionUtils.USER_DECLARED_METHODS
.and(method -> AnnotatedElementUtils.hasAnnotation(method, Sql.class));
/**
* Returns {@code 5000}.
@ -129,6 +139,21 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen @@ -129,6 +139,21 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
executeSqlScripts(testContext, ExecutionPhase.AFTER_TEST_METHOD);
}
/**
* Process the supplied test class and its methods and register run-time
* hints for any SQL scripts configured or detected as classpath resources
* via {@link Sql @Sql}.
* @since 6.0
*/
@Override
public void processAheadOfTime(Class<?> testClass, RuntimeHints runtimeHints, ClassLoader classLoader) {
getSqlAnnotationsFor(testClass).forEach(sql ->
registerClasspathResources(runtimeHints, getScripts(sql, testClass, null, true)));
getSqlMethods(testClass).forEach(testMethod ->
getSqlAnnotationsFor(testMethod).forEach(sql ->
registerClasspathResources(runtimeHints, getScripts(sql, testClass, testMethod, false))));
}
/**
* Execute SQL scripts configured via {@link Sql @Sql} for the supplied
* {@link TestContext} and {@link ExecutionPhase}.
@ -226,8 +251,7 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen @@ -226,8 +251,7 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
mergedSqlConfig, executionPhase, testContext));
}
String[] scripts = getScripts(sql, testContext, classLevel);
scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts);
String[] scripts = getScripts(sql, testContext.getTestClass(), testContext.getTestMethod(), classLevel);
List<Resource> scriptResources = TestContextResourceUtils.convertToResourceList(
testContext.getApplicationContext(), scripts);
for (String stmt : sql.statements()) {
@ -321,31 +345,29 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen @@ -321,31 +345,29 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
return null;
}
private String[] getScripts(Sql sql, TestContext testContext, boolean classLevel) {
private String[] getScripts(Sql sql, Class<?> testClass, Method testMethod, boolean classLevel) {
String[] scripts = sql.scripts();
if (ObjectUtils.isEmpty(scripts) && ObjectUtils.isEmpty(sql.statements())) {
scripts = new String[] {detectDefaultScript(testContext, classLevel)};
scripts = new String[] {detectDefaultScript(testClass, testMethod, classLevel)};
}
return scripts;
return TestContextResourceUtils.convertToClasspathResourcePaths(testClass, scripts);
}
/**
* Detect a default SQL script by implementing the algorithm defined in
* {@link Sql#scripts}.
*/
private String detectDefaultScript(TestContext testContext, boolean classLevel) {
Class<?> clazz = testContext.getTestClass();
Method method = testContext.getTestMethod();
private String detectDefaultScript(Class<?> testClass, Method testMethod, boolean classLevel) {
String elementType = (classLevel ? "class" : "method");
String elementName = (classLevel ? clazz.getName() : method.toString());
String elementName = (classLevel ? testClass.getName() : testMethod.toString());
String resourcePath = ClassUtils.convertClassNameToResourcePath(clazz.getName());
String resourcePath = ClassUtils.convertClassNameToResourcePath(testClass.getName());
if (!classLevel) {
resourcePath += "." + method.getName();
resourcePath += "." + testMethod.getName();
}
resourcePath += ".sql";
String prefixedResourcePath = ResourceUtils.CLASSPATH_URL_PREFIX + SLASH + resourcePath;
String prefixedResourcePath = CLASSPATH_URL_PREFIX + SLASH + resourcePath;
ClassPathResource classPathResource = new ClassPathResource(resourcePath);
if (classPathResource.exists()) {
@ -364,4 +386,23 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen @@ -364,4 +386,23 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
}
}
private Stream<Method> getSqlMethods(Class<?> testClass) {
return Arrays.stream(ReflectionUtils.getUniqueDeclaredMethods(testClass, sqlMethodFilter));
}
private void registerClasspathResources(RuntimeHints runtimeHints, String... locations) {
Arrays.stream(locations)
.filter(location -> location.startsWith(CLASSPATH_URL_PREFIX))
.map(this::cleanClasspathResource)
.forEach(runtimeHints.resources()::registerPattern);
}
private String cleanClasspathResource(String location) {
location = location.substring(CLASSPATH_URL_PREFIX.length());
if (!location.startsWith(SLASH)) {
location = SLASH + location;
}
return location;
}
}

6
spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorTests.java

@ -191,6 +191,12 @@ class TestContextAotGeneratorTests extends AbstractAotTests { @@ -191,6 +191,12 @@ class TestContextAotGeneratorTests extends AbstractAotTests {
// @WebAppConfiguration(value = ...)
assertThat(resource().forResource("/META-INF/web-resources/resources/Spring.js")).accepts(runtimeHints);
assertThat(resource().forResource("/META-INF/web-resources/WEB-INF/views/home.jsp")).accepts(runtimeHints);
// @Sql(scripts = ...)
assertThat(resource().forResource("/org/springframework/test/context/jdbc/schema.sql"))
.accepts(runtimeHints);
assertThat(resource().forResource("/org/springframework/test/context/aot/samples/jdbc/SqlScriptsSpringJupiterTests.test.sql"))
.accepts(runtimeHints);
}
private static void assertReflectionRegistered(RuntimeHints runtimeHints, String type, MemberCategory memberCategory) {

Loading…
Cancel
Save