diff --git a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java index c106cb8162..de9c519135 100644 --- a/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java +++ b/spring-context/src/main/java/org/springframework/context/annotation/ConfigurationClassPostProcessor.java @@ -59,6 +59,7 @@ import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; import org.springframework.beans.factory.support.BeanNameGenerator; import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.ApplicationStartupAware; import org.springframework.context.EnvironmentAware; import org.springframework.context.ResourceLoaderAware; @@ -518,6 +519,10 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo private static final String MAPPINGS_VARIABLE = "mappings"; + private static final String BEAN_DEFINITION_VARIABLE = "beanDefinition"; + + private static final String BEAN_NAME = "org.springframework.context.annotation.internalImportAwareAotProcessor"; + private final ConfigurableListableBeanFactory beanFactory; @@ -561,9 +566,12 @@ public class ConfigurationClassPostProcessor implements BeanDefinitionRegistryPo MAPPINGS_VARIABLE, HashMap.class); mappings.forEach((type, from) -> builder.addStatement("$L.put($S, $S)", MAPPINGS_VARIABLE, type, from)); - builder.addStatement("$L.addBeanPostProcessor(new $T($L))", - BEAN_FACTORY_VARIABLE, ImportAwareAotBeanPostProcessor.class, - MAPPINGS_VARIABLE); + builder.addStatement("$T $L = new $T($T.class)", RootBeanDefinition.class, + BEAN_DEFINITION_VARIABLE, RootBeanDefinition.class, ImportAwareAotBeanPostProcessor.class); + builder.addStatement("$L.getConstructorArgumentValues().addIndexedArgumentValue(0, $L)", + BEAN_DEFINITION_VARIABLE, MAPPINGS_VARIABLE); + builder.addStatement("$L.registerBeanDefinition($S, $L)", + BEAN_FACTORY_VARIABLE, BEAN_NAME, BEAN_DEFINITION_VARIABLE); return builder.build(); } diff --git a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java index 215a19c61e..fe97f4e81f 100644 --- a/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java +++ b/spring-context/src/test/java/org/springframework/context/annotation/ConfigurationClassPostProcessorAotContributionTests.java @@ -16,6 +16,8 @@ package org.springframework.context.annotation; +import java.util.ArrayList; +import java.util.List; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -28,14 +30,20 @@ import org.springframework.aot.generate.MethodReference; import org.springframework.aot.hint.ResourcePatternHint; import org.springframework.aot.test.generator.compile.Compiled; import org.springframework.aot.test.generator.compile.TestCompiler; +import org.springframework.beans.BeansException; import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution; +import org.springframework.beans.factory.config.BeanPostProcessor; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.testfixture.beans.factory.aot.MockBeanFactoryInitializationCode; import org.springframework.beans.testfixture.beans.factory.generator.SimpleConfiguration; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.context.support.GenericApplicationContext; import org.springframework.context.testfixture.context.generator.annotation.ImportAwareConfiguration; import org.springframework.context.testfixture.context.generator.annotation.ImportConfiguration; import org.springframework.core.testfixture.aot.generate.TestGenerationContext; +import org.springframework.core.type.AnnotationMetadata; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; @@ -62,6 +70,10 @@ class ConfigurationClassPostProcessorAotContributionTests { this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(this.generationContext); } + @Test + void processAheadOfTimeWhenNoImportAwareConfigurationReturnsNull() { + assertThat(getContribution(SimpleConfiguration.class)).isNull(); + } @Test void applyToWhenHasImportAwareConfigurationRegistersBeanPostProcessorWithMapEntry() { @@ -69,12 +81,32 @@ class ConfigurationClassPostProcessorAotContributionTests { ImportConfiguration.class); contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode); compile((initializer, compiled) -> { - DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory(); + GenericApplicationContext freshContext = new GenericApplicationContext(); + DefaultListableBeanFactory freshBeanFactory = freshContext.getDefaultListableBeanFactory(); + initializer.accept(freshBeanFactory); + freshContext.refresh(); + assertThat(freshBeanFactory.getBeanPostProcessors()).filteredOn(ImportAwareAotBeanPostProcessor.class::isInstance) + .singleElement().satisfies(postProcessor -> assertPostProcessorEntry(postProcessor, ImportAwareConfiguration.class, + ImportConfiguration.class)); + }); + } + + @Test + void applyToWhenHasImportAwareConfigurationRegistersBeanPostProcessorAfterApplicationContextAwareProcessor() { + BeanFactoryInitializationAotContribution contribution = getContribution( + ImportConfiguration.class); + contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode); + compile((initializer, compiled) -> { + GenericApplicationContext freshContext = new AnnotationConfigApplicationContext(); + DefaultListableBeanFactory freshBeanFactory = freshContext.getDefaultListableBeanFactory(); initializer.accept(freshBeanFactory); - ImportAwareAotBeanPostProcessor postProcessor = (ImportAwareAotBeanPostProcessor) freshBeanFactory - .getBeanPostProcessors().get(0); - assertPostProcessorEntry(postProcessor, ImportAwareConfiguration.class, - ImportConfiguration.class); + freshContext.registerBean(TestAwareCallbackConfiguration.class); + freshContext.refresh(); + TestAwareCallbackBean bean = freshContext.getBean(TestAwareCallbackBean.class); + assertThat(bean.instances).hasSize(2); + assertThat(bean.instances.get(0)).isEqualTo(freshContext); + assertThat(bean.instances.get(1)).isInstanceOfSatisfying(AnnotationMetadata.class, metadata -> + assertThat(metadata.getClassName()).isEqualTo(TestAwareCallbackConfiguration.class.getName())); }); } @@ -91,11 +123,6 @@ class ConfigurationClassPostProcessorAotContributionTests { + "ImportConfiguration.class")); } - @Test - void processAheadOfTimeWhenNoImportAwareConfigurationReturnsNull() { - assertThat(getContribution(SimpleConfiguration.class)).isNull(); - } - @Nullable private BeanFactoryInitializationAotContribution getContribution(Class type) { DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); @@ -105,6 +132,13 @@ class ConfigurationClassPostProcessorAotContributionTests { return postProcessor.processAheadOfTime(beanFactory); } + private void assertPostProcessorEntry(BeanPostProcessor postProcessor, + Class key, Class value) { + assertThat(postProcessor).extracting("importsMapping") + .asInstanceOf(InstanceOfAssertFactories.MAP) + .containsExactly(entry(key.getName(), value.getName())); + } + @SuppressWarnings("unchecked") private void compile(BiConsumer, Compiled> result) { MethodReference methodReference = this.beanFactoryInitializationCode @@ -122,11 +156,26 @@ class ConfigurationClassPostProcessorAotContributionTests { result.accept(compiled.getInstance(Consumer.class), compiled)); } - private void assertPostProcessorEntry(ImportAwareAotBeanPostProcessor postProcessor, - Class key, Class value) { - assertThat(postProcessor).extracting("importsMapping") - .asInstanceOf(InstanceOfAssertFactories.MAP) - .containsExactly(entry(key.getName(), value.getName())); + @Configuration(proxyBeanMethods = false) + @Import(TestAwareCallbackBean.class) + static class TestAwareCallbackConfiguration { + + } + + static class TestAwareCallbackBean implements ImportAware, ApplicationContextAware { + + private final List instances = new ArrayList<>(); + + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.instances.add(applicationContext); + } + + @Override + public void setImportMetadata(AnnotationMetadata importMetadata) { + this.instances.add(importMetadata); + } + } }