diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java index 6986183d14..13756a8fc6 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGenerator.java @@ -27,7 +27,6 @@ import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.MethodReference; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.RegisteredBean; -import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.javapoet.ClassName; import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; @@ -66,10 +65,6 @@ class BeanDefinitionMethodGenerator { RegisteredBean registeredBean, @Nullable String currentPropertyName, List aotContributions) { - RootBeanDefinition mbd = registeredBean.getMergedBeanDefinition(); - if (mbd.getInstanceSupplier() != null && aotContributions.isEmpty()) { - throw new IllegalArgumentException("Code generation is not supported for bean definitions declaring an instance supplier callback : " + mbd); - } this.methodGeneratorFactory = methodGeneratorFactory; this.registeredBean = registeredBean; this.currentPropertyName = currentPropertyName; diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java index e51aebd7f0..77952134cc 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/aot/DefaultBeanRegistrationCodeFragments.java @@ -73,6 +73,10 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme @Override public ClassName getTarget(RegisteredBean registeredBean) { + if (hasInstanceSupplier()) { + throw new IllegalStateException("Default code generation is not supported for bean definitions " + + "declaring an instance supplier callback: " + registeredBean.getMergedBeanDefinition()); + } Class target = extractDeclaringClass(registeredBean.getBeanType(), this.constructorOrFactoryMethod.get()); while (target.getName().startsWith("java.") && registeredBean.isInnerBean()) { RegisteredBean parent = registeredBean.getParent(); @@ -224,7 +228,10 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme @Override public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { - + if (hasInstanceSupplier()) { + throw new IllegalStateException("Default code generation is not supported for bean definitions declaring " + + "an instance supplier callback: " + this.registeredBean.getMergedBeanDefinition()); + } return new InstanceSupplierCodeGenerator(generationContext, beanRegistrationCode.getClassName(), beanRegistrationCode.getMethods(), allowDirectSupplierShortcut) .generateCode(this.registeredBean,this.constructorOrFactoryMethod.get()); @@ -239,4 +246,8 @@ class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragme return code.build(); } + private boolean hasInstanceSupplier() { + return this.registeredBean.getMergedBeanDefinition().getInstanceSupplier() != null; + } + } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index 3d490ec1a2..0c5bcc7343 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -57,13 +57,14 @@ import org.springframework.core.test.tools.CompileWithForkedClassLoader; import org.springframework.core.test.tools.Compiled; import org.springframework.core.test.tools.SourceFile; import org.springframework.core.test.tools.TestCompiler; +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterizedTypeName; import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; /** * Tests for {@link BeanDefinitionMethodGenerator} and @@ -172,22 +173,6 @@ class BeanDefinitionMethodGeneratorTests { }); } - @Test // gh-29556 - void generateBeanDefinitionMethodGeneratesMethodWithInstanceSupplier() { - RegisteredBean registeredBean = registerBean(new RootBeanDefinition(TestBean.class, TestBean::new)); - BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( - this.methodGeneratorFactory, registeredBean, null, - List.of((generationContext, beanRegistrationCode) -> { })); - MethodReference method = generator.generateBeanDefinitionMethod( - this.generationContext, this.beanRegistrationsCode); - compile(method, (actual, compiled) -> { - SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); - assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); - assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); - assertThat(actual).isInstanceOf(RootBeanDefinition.class); - }); - } - @Test void generateBeanDefinitionMethodWhenHasInnerClassTargetMethodGeneratesMethod() { this.beanFactory.registerBeanDefinition("testBeanConfiguration", new RootBeanDefinition( @@ -591,12 +576,82 @@ class BeanDefinitionMethodGeneratorTests { testBeanDefinitionMethodInCurrentFile(Boolean.class, beanDefinition); } - @Test // gh-29556 - void throwExceptionWithInstanceSupplierWithoutAotContribution() { + @Test + void generateBeanDefinitionMethodWhenInstanceSupplierWithNoCustomization() { + RegisteredBean registeredBean = registerBean(new RootBeanDefinition(TestBean.class, TestBean::new)); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + List.of()); + assertThatIllegalStateException().isThrownBy(() -> generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode)).withMessageStartingWith( + "Default code generation is not supported for bean definitions declaring an instance supplier callback"); + } + + @Test + void generateBeanDefinitionMethodWhenInstanceSupplierWithOnlyCustomTarget() { + BeanRegistrationAotContribution aotContribution = BeanRegistrationAotContribution.withCustomCodeFragments( + defaultCodeFragments -> new BeanRegistrationCodeFragmentsDecorator(defaultCodeFragments) { + @Override + public ClassName getTarget(RegisteredBean registeredBean) { + return ClassName.get(TestBean.class); + } + }); RegisteredBean registeredBean = registerBean(new RootBeanDefinition(TestBean.class, TestBean::new)); - assertThatIllegalArgumentException().isThrownBy(() -> new BeanDefinitionMethodGenerator( + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( this.methodGeneratorFactory, registeredBean, null, - Collections.emptyList())); + List.of(aotContribution)); + assertThatIllegalStateException().isThrownBy(() -> generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode)).withMessageStartingWith( + "Default code generation is not supported for bean definitions declaring an instance supplier callback"); + } + + @Test + void generateBeanDefinitionMethodWhenInstanceSupplierWithOnlyCustomInstanceSupplier() { + BeanRegistrationAotContribution aotContribution = BeanRegistrationAotContribution.withCustomCodeFragments( + defaultCodeFragments -> new BeanRegistrationCodeFragmentsDecorator(defaultCodeFragments) { + @Override + public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { + return CodeBlock.of("// custom"); + } + }); + RegisteredBean registeredBean = registerBean(new RootBeanDefinition(TestBean.class, TestBean::new)); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + List.of(aotContribution)); + assertThatIllegalStateException().isThrownBy(() -> generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode)).withMessageStartingWith( + "Default code generation is not supported for bean definitions declaring an instance supplier callback"); + } + + @Test + void generateBeanDefinitionMethodWhenInstanceSupplierWithCustomInstanceSupplierAndCustomTarget() { + BeanRegistrationAotContribution aotContribution = BeanRegistrationAotContribution.withCustomCodeFragments( + defaultCodeFragments -> new BeanRegistrationCodeFragmentsDecorator(defaultCodeFragments) { + + @Override + public ClassName getTarget(RegisteredBean registeredBean) { + return ClassName.get(TestBean.class); + } + + @Override + public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { + return CodeBlock.of("$T::new", TestBean.class); + } + }); + RegisteredBean registeredBean = registerBean(new RootBeanDefinition(TestBean.class, TestBean::new)); + BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator( + this.methodGeneratorFactory, registeredBean, null, + List.of(aotContribution)); + MethodReference method = generator.generateBeanDefinitionMethod( + this.generationContext, this.beanRegistrationsCode); + compile(method, (actual, compiled) -> { + SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions"); + assertThat(sourceFile).contains("Get the bean definition for 'testBean'"); + assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)"); + assertThat(actual).isInstanceOf(RootBeanDefinition.class); + }); } private void testBeanDefinitionMethodInCurrentFile(Class targetType, RootBeanDefinition beanDefinition) {