From 26054fd3d4aba89474f2173c7be44db6937aaf51 Mon Sep 17 00:00:00 2001 From: Stephane Nicoll Date: Thu, 21 Apr 2022 17:01:40 +0200 Subject: [PATCH] AOT contribution for @PersistenceContext and @PersistenceUnit Closes gh-28364 --- .../factory/generator/BeanFieldGenerator.java | 65 +++++ .../factory/generator/InjectionGenerator.java | 27 +-- .../generator/BeanFieldGeneratorTests.java | 78 ++++++ .../javapoet/support/MultiStatement.java | 10 + .../javapoet/support/MultiStatementTests.java | 13 + spring-orm/spring-orm.gradle | 1 + ...ersistenceAnnotationBeanPostProcessor.java | 94 +++++++- ...tenceAnnotationBeanPostProcessorTests.java | 225 ++++++++++++++++++ 8 files changed, 491 insertions(+), 22 deletions(-) create mode 100644 spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFieldGenerator.java create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanFieldGeneratorTests.java create mode 100644 spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorTests.java diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFieldGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFieldGenerator.java new file mode 100644 index 0000000000..224a7d7a93 --- /dev/null +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFieldGenerator.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; + +import org.springframework.aot.generator.ProtectedAccess.Options; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.support.MultiStatement; +import org.springframework.util.ReflectionUtils; + +/** + * Support for generating {@link Field} access. + * + * @author Stephane Nicoll + * @since 6.0 + */ +public class BeanFieldGenerator { + + /** + * The {@link Options} to use to access a field. + */ + public static final Options FIELD_OPTIONS = Options.defaults() + .useReflection(member -> Modifier.isPrivate(member.getModifiers())).build(); + + + /** + * Generate the necessary code to set the specified field. Use reflection + * using {@link ReflectionUtils} if necessary. + * @param field the field to set + * @param value a code representation of the field value + * @return the code to set the specified field + */ + public MultiStatement generateSetValue(String target, Field field, CodeBlock value) { + MultiStatement statement = new MultiStatement(); + boolean useReflection = Modifier.isPrivate(field.getModifiers()); + if (useReflection) { + String fieldName = String.format("%sField", field.getName()); + statement.addStatement("$T $L = $T.findField($T.class, $S)", Field.class, fieldName, ReflectionUtils.class, + field.getDeclaringClass(), field.getName()); + statement.addStatement("$T.makeAccessible($L)", ReflectionUtils.class, fieldName); + statement.addStatement("$T.setField($L, $L, $L)", ReflectionUtils.class, fieldName, target, value); + } + else { + statement.addStatement("$L.$L = $L", target, field.getName(), value); + } + return statement; + } + +} diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/generator/InjectionGenerator.java b/spring-beans/src/main/java/org/springframework/beans/factory/generator/InjectionGenerator.java index eb4f7cb8c1..625218f261 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/generator/InjectionGenerator.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/generator/InjectionGenerator.java @@ -34,7 +34,6 @@ import org.springframework.beans.factory.generator.config.BeanDefinitionRegistra import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock.Builder; import org.springframework.util.ClassUtils; -import org.springframework.util.ReflectionUtils; /** * Generate the necessary code to {@link #generateInstantiation(Executable) @@ -53,14 +52,13 @@ import org.springframework.util.ReflectionUtils; */ public class InjectionGenerator { - private static final Options FIELD_INJECTION_OPTIONS = Options.defaults() - .useReflection(member -> Modifier.isPrivate(member.getModifiers())).build(); - private static final Options METHOD_INJECTION_OPTIONS = Options.defaults() .useReflection(member -> false).build(); private final BeanParameterGenerator parameterGenerator = new BeanParameterGenerator(); + private final BeanFieldGenerator fieldGenerator = new BeanFieldGenerator(); + /** * Generate the necessary code to instantiate an object using the specified @@ -110,7 +108,7 @@ public class InjectionGenerator { return METHOD_INJECTION_OPTIONS; } if (member instanceof Field) { - return FIELD_INJECTION_OPTIONS; + return BeanFieldGenerator.FIELD_OPTIONS; } throw new IllegalArgumentException("Could not handle member " + member); } @@ -230,24 +228,13 @@ public class InjectionGenerator { code.add("instanceContext.field($S", injectionPoint.getName()); code.add(")\n").indent().indent(); if (required) { - code.add(".invoke(beanFactory, (attributes) ->"); - } - else { - code.add(".resolve(beanFactory, false).ifResolved((attributes) ->"); - } - boolean hasAssignment = Modifier.isPrivate(injectionPoint.getModifiers()); - if (hasAssignment) { - code.beginControlFlow(""); - String fieldName = String.format("%sField", injectionPoint.getName()); - code.addStatement("$T $L = $T.findField($T.class, $S)", Field.class, fieldName, ReflectionUtils.class, - injectionPoint.getDeclaringClass(), injectionPoint.getName()); - code.addStatement("$T.makeAccessible($L)", ReflectionUtils.class, fieldName); - code.addStatement("$T.setField($L, bean, attributes.get(0))", ReflectionUtils.class, fieldName); - code.unindent().add("}"); + code.add(".invoke(beanFactory, "); } else { - code.add(" bean.$L = attributes.get(0)", injectionPoint.getName()); + code.add(".resolve(beanFactory, false).ifResolved("); } + code.add(this.fieldGenerator.generateSetValue("bean", injectionPoint, + CodeBlock.of("attributes.get(0)")).toLambdaBody("(attributes) ->")); code.add(")").unindent().unindent(); return code.build(); } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanFieldGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanFieldGeneratorTests.java new file mode 100644 index 0000000000..36e55456fd --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanFieldGeneratorTests.java @@ -0,0 +1,78 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.beans.factory.generator; + +import java.lang.reflect.Field; + +import org.junit.jupiter.api.Test; + +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.support.CodeSnippet; +import org.springframework.javapoet.support.MultiStatement; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link BeanFieldGenerator}. + * + * @author Stephane Nicoll + */ +class BeanFieldGeneratorTests { + + private final BeanFieldGenerator generator = new BeanFieldGenerator(); + + @Test + void generateSetFieldWithPublicField() { + MultiStatement statement = this.generator.generateSetValue("bean", + field(SampleBean.class, "one"), CodeBlock.of("$S", "test")); + assertThat(CodeSnippet.process(statement.toCodeBlock())).isEqualTo(""" + bean.one = "test"; + """); + } + + @Test + void generateSetFieldWithPrivateField() { + MultiStatement statement = this.generator.generateSetValue("example", + field(SampleBean.class, "two"), CodeBlock.of("42")); + CodeSnippet code = CodeSnippet.of(statement.toCodeBlock()); + assertThat(code.getSnippet()).isEqualTo(""" + Field twoField = ReflectionUtils.findField(BeanFieldGeneratorTests.SampleBean.class, "two"); + ReflectionUtils.makeAccessible(twoField); + ReflectionUtils.setField(twoField, example, 42); + """); + assertThat(code.hasImport(ReflectionUtils.class)).isTrue(); + assertThat(code.hasImport(BeanFieldGeneratorTests.class)).isTrue(); + } + + + private Field field(Class type, String name) { + Field field = ReflectionUtils.findField(type, name); + assertThat(field).isNotNull(); + return field; + } + + + public static class SampleBean { + + public String one; + + private int two; + + } + +} diff --git a/spring-core/src/main/java/org/springframework/javapoet/support/MultiStatement.java b/spring-core/src/main/java/org/springframework/javapoet/support/MultiStatement.java index 2fbf56df1c..e566e5310f 100644 --- a/spring-core/src/main/java/org/springframework/javapoet/support/MultiStatement.java +++ b/spring-core/src/main/java/org/springframework/javapoet/support/MultiStatement.java @@ -44,6 +44,16 @@ public final class MultiStatement { return this.statements.isEmpty(); } + /** + * Add the statements defined in the specified multi statement to this instance. + * @param multiStatement the statements to add + * @return {@code this}, to facilitate method chaining + */ + public MultiStatement add(MultiStatement multiStatement) { + this.statements.addAll(multiStatement.statements); + return this; + } + /** * Add the specified {@link CodeBlock codeblock} rendered as-is. * @param codeBlock the code block to add diff --git a/spring-core/src/test/java/org/springframework/javapoet/support/MultiStatementTests.java b/spring-core/src/test/java/org/springframework/javapoet/support/MultiStatementTests.java index fdbd050346..6475ab7c10 100644 --- a/spring-core/src/test/java/org/springframework/javapoet/support/MultiStatementTests.java +++ b/spring-core/src/test/java/org/springframework/javapoet/support/MultiStatementTests.java @@ -150,4 +150,17 @@ class MultiStatementTests { "}"); } + @Test + void addWithAnotherMultiStatement() { + MultiStatement statements = new MultiStatement(); + statements.addStatement(CodeBlock.of("test.invoke()")); + MultiStatement another = new MultiStatement(); + another.addStatement(CodeBlock.of("test.another()")); + statements.add(another); + assertThat(statements.toCodeBlock().toString()).isEqualTo(""" + test.invoke(); + test.another(); + """); + } + } diff --git a/spring-orm/spring-orm.gradle b/spring-orm/spring-orm.gradle index d7d2bebf52..c2c5a4afc9 100644 --- a/spring-orm/spring-orm.gradle +++ b/spring-orm/spring-orm.gradle @@ -11,6 +11,7 @@ dependencies { optional("org.eclipse.persistence:org.eclipse.persistence.jpa") optional("org.hibernate:hibernate-core-jakarta") optional("jakarta.servlet:jakarta.servlet-api") + testImplementation(project(":spring-core-test")) testImplementation(testFixtures(project(":spring-beans"))) testImplementation(testFixtures(project(":spring-context"))) testImplementation(testFixtures(project(":spring-core"))) diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java index 66971589a3..d61ee63246 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,11 +19,13 @@ package org.springframework.orm.jpa.support; import java.beans.PropertyDescriptor; import java.io.Serializable; import java.lang.reflect.AnnotatedElement; +import java.lang.reflect.Field; import java.lang.reflect.Member; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Properties; @@ -37,6 +39,8 @@ import jakarta.persistence.PersistenceProperty; import jakarta.persistence.PersistenceUnit; import jakarta.persistence.SynchronizationType; +import org.springframework.aot.generator.CodeContribution; +import org.springframework.aot.generator.ProtectedAccess.Options; import org.springframework.beans.BeanUtils; import org.springframework.beans.PropertyValues; import org.springframework.beans.factory.BeanCreationException; @@ -45,17 +49,23 @@ import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.beans.factory.ListableBeanFactory; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.annotation.InjectionMetadata; +import org.springframework.beans.factory.annotation.InjectionMetadata.InjectedElement; import org.springframework.beans.factory.config.ConfigurableBeanFactory; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.DestructionAwareBeanPostProcessor; import org.springframework.beans.factory.config.InstantiationAwareBeanPostProcessor; import org.springframework.beans.factory.config.NamedBeanHolder; +import org.springframework.beans.factory.generator.AotContributingBeanPostProcessor; +import org.springframework.beans.factory.generator.BeanFieldGenerator; +import org.springframework.beans.factory.generator.BeanInstantiationContribution; import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcessor; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.core.BridgeMethodResolver; import org.springframework.core.Ordered; import org.springframework.core.PriorityOrdered; import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.support.MultiStatement; import org.springframework.jndi.JndiLocatorDelegate; import org.springframework.jndi.JndiTemplate; import org.springframework.lang.Nullable; @@ -66,6 +76,7 @@ import org.springframework.orm.jpa.ExtendedEntityManagerCreator; import org.springframework.orm.jpa.SharedEntityManagerCreator; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; import org.springframework.util.ObjectUtils; import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; @@ -164,6 +175,7 @@ import org.springframework.util.StringUtils; * * @author Rod Johnson * @author Juergen Hoeller + * @author Stephane Nicoll * @since 2.0 * @see jakarta.persistence.PersistenceUnit * @see jakarta.persistence.PersistenceContext @@ -171,7 +183,8 @@ import org.springframework.util.StringUtils; @SuppressWarnings("serial") public class PersistenceAnnotationBeanPostProcessor implements InstantiationAwareBeanPostProcessor, DestructionAwareBeanPostProcessor, - MergedBeanDefinitionPostProcessor, PriorityOrdered, BeanFactoryAware, Serializable { + MergedBeanDefinitionPostProcessor, AotContributingBeanPostProcessor, + PriorityOrdered, BeanFactoryAware, Serializable { @Nullable private Object jndiEnvironment; @@ -332,8 +345,23 @@ public class PersistenceAnnotationBeanPostProcessor @Override public void postProcessMergedBeanDefinition(RootBeanDefinition beanDefinition, Class beanType, String beanName) { + findInjectionMetadata(beanDefinition, beanType, beanName); + } + + @Override + public BeanInstantiationContribution contribute(RootBeanDefinition beanDefinition, Class beanType, String beanName) { + InjectionMetadata metadata = findInjectionMetadata(beanDefinition, beanType, beanName); + Collection injectedElements = metadata.getInjectedElements(); + if (!CollectionUtils.isEmpty(injectedElements)) { + return new PersistenceAnnotationBeanInstantiationContribution(injectedElements); + } + return null; + } + + private InjectionMetadata findInjectionMetadata(RootBeanDefinition beanDefinition, Class beanType, String beanName) { InjectionMetadata metadata = findPersistenceMetadata(beanName, beanType, null); metadata.checkConfigMembers(beanDefinition); + return metadata; } @Override @@ -725,4 +753,66 @@ public class PersistenceAnnotationBeanPostProcessor } } + private static final class PersistenceAnnotationBeanInstantiationContribution implements BeanInstantiationContribution { + + private static final BeanFieldGenerator fieldGenerator = new BeanFieldGenerator(); + + private final Collection injectedElements; + + private PersistenceAnnotationBeanInstantiationContribution(Collection injectedElements) { + this.injectedElements = injectedElements.stream() + .filter(obj -> obj instanceof PersistenceElement) + .map(PersistenceElement.class::cast).toList(); + } + + @Override + public void applyTo(CodeContribution contribution) { + this.injectedElements.forEach(element -> { + Member member = element.getMember(); + analyzeMember(contribution, member); + injectElement(contribution, element); + }); + } + + private void analyzeMember(CodeContribution contribution, Member member) { + if (member instanceof Method) { + contribution.protectedAccess().analyze(member, Options.defaults().build()); + } + else if (member instanceof Field field) { + contribution.protectedAccess().analyze(member, BeanFieldGenerator.FIELD_OPTIONS); + if (Modifier.isPrivate(field.getModifiers())) { + contribution.runtimeHints().reflection().registerField(field); + } + } + } + + private void injectElement(CodeContribution contribution, PersistenceElement element) { + MultiStatement statements = contribution.statements(); + statements.addStatement("$T entityManagerFactory = $T.findEntityManagerFactory(beanFactory, $S)", + EntityManagerFactory.class, EntityManagerFactoryUtils.class, element.unitName); + boolean requireEntityManager = (element.type != null); + if (requireEntityManager) { + Properties persistenceProperties = element.properties; + boolean hasPersistenceProperties = persistenceProperties != null && !persistenceProperties.isEmpty(); + if (hasPersistenceProperties) { + statements.addStatement("$T persistenceProperties = new Properties()", Properties.class); + persistenceProperties.stringPropertyNames().stream().sorted(String::compareTo).forEach(propertyName -> + statements.addStatement("persistenceProperties.put($S, $S)", + propertyName, persistenceProperties.getProperty(propertyName))); + } + statements.addStatement("$T entityManager = $T.createSharedEntityManager(entityManagerFactory, $L, $L)", + EntityManager.class, SharedEntityManagerCreator.class, (hasPersistenceProperties) ? "persistenceProperties" : null, element.synchronizedWithTransaction); + } + Member member = element.getMember(); + CodeBlock value = (requireEntityManager) ? CodeBlock.of("entityManager") : CodeBlock.of("entityManagerFactory"); + if (member instanceof Field field) { + statements.add(fieldGenerator.generateSetValue("bean", field, value)); + } + else { + statements.addStatement("bean.$L($L)", member.getName(), value); + } + } + + } + } diff --git a/spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorTests.java b/spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorTests.java new file mode 100644 index 0000000000..03217c6305 --- /dev/null +++ b/spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorTests.java @@ -0,0 +1,225 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.orm.jpa.support; + +import java.util.function.Consumer; +import java.util.function.Supplier; + +import jakarta.persistence.EntityManager; +import jakarta.persistence.EntityManagerFactory; +import jakarta.persistence.PersistenceContext; +import jakarta.persistence.PersistenceProperty; +import jakarta.persistence.PersistenceUnit; +import org.junit.jupiter.api.Test; + +import org.springframework.aot.generator.CodeContribution; +import org.springframework.aot.generator.DefaultCodeContribution; +import org.springframework.aot.generator.DefaultGeneratedTypeContext; +import org.springframework.aot.generator.GeneratedType; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.TypeReference; +import org.springframework.aot.test.generator.compile.TestCompiler; +import org.springframework.aot.test.generator.file.SourceFile; +import org.springframework.aot.test.generator.file.SourceFiles; +import org.springframework.beans.factory.generator.BeanInstantiationContribution; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.generator.ApplicationContextAotGenerator; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.JavaFile; +import org.springframework.javapoet.support.CodeSnippet; +import org.springframework.lang.Nullable; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link PersistenceAnnotationBeanPostProcessor}. + * + * @author Stephane Nicoll + */ +class PersistenceAnnotationBeanPostProcessorTests { + + @Test + void contributeForPersistenceUnitOnPublicField() { + CodeContribution contribution = contribute(DefaultPersistenceUnitField.class); + assertThat(contribution).isNotNull(); + assertThat(contribution.runtimeHints().reflection().typeHints()).isEmpty(); + assertThat(CodeSnippet.process(contribution.statements().toCodeBlock())).isEqualTo(""" + EntityManagerFactory entityManagerFactory = EntityManagerFactoryUtils.findEntityManagerFactory(beanFactory, ""); + bean.emf = entityManagerFactory; + """); + } + + @Test + void contributeForPersistenceUnitOnPublicSetter() { + CodeContribution contribution = contribute(DefaultPersistenceUnitMethod.class); + assertThat(contribution).isNotNull(); + assertThat(contribution.runtimeHints().reflection().typeHints()).isEmpty(); + assertThat(CodeSnippet.process(contribution.statements().toCodeBlock())).isEqualTo(""" + EntityManagerFactory entityManagerFactory = EntityManagerFactoryUtils.findEntityManagerFactory(beanFactory, ""); + bean.setEmf(entityManagerFactory); + """); + } + + @Test + void contributeForPersistenceUnitWithCustomUnitOnPublicSetter() { + CodeContribution contribution = contribute(CustomUnitNamePublicPersistenceUnitMethod.class); + assertThat(contribution).isNotNull(); + assertThat(contribution.runtimeHints().reflection().typeHints()).isEmpty(); + assertThat(CodeSnippet.process(contribution.statements().toCodeBlock())).isEqualTo(""" + EntityManagerFactory entityManagerFactory = EntityManagerFactoryUtils.findEntityManagerFactory(beanFactory, "custom"); + bean.setEmf(entityManagerFactory); + """); + } + + @Test + void contributeForPersistenceContextOnPrivateField() { + CodeContribution contribution = contribute(DefaultPersistenceContextField.class); + assertThat(contribution).isNotNull(); + assertThat(contribution.runtimeHints().reflection().typeHints()).singleElement().satisfies(typeHint -> { + assertThat(typeHint.getType()).isEqualTo(TypeReference.of(DefaultPersistenceContextField.class)); + assertThat(typeHint.fields()).singleElement().satisfies(fieldHint -> { + assertThat(fieldHint.getName()).isEqualTo("entityManager"); + assertThat(fieldHint.isAllowWrite()).isTrue(); + assertThat(fieldHint.isAllowUnsafeAccess()).isFalse(); + }); + }); + assertThat(CodeSnippet.process(contribution.statements().toCodeBlock())).isEqualTo(""" + EntityManagerFactory entityManagerFactory = EntityManagerFactoryUtils.findEntityManagerFactory(beanFactory, ""); + EntityManager entityManager = SharedEntityManagerCreator.createSharedEntityManager(entityManagerFactory, null, true); + Field entityManagerField = ReflectionUtils.findField(PersistenceAnnotationBeanPostProcessorTests.DefaultPersistenceContextField.class, "entityManager"); + ReflectionUtils.makeAccessible(entityManagerField); + ReflectionUtils.setField(entityManagerField, bean, entityManager); + """); + } + + @Test + void contributeForPersistenceContextWithCustomPropertiesOnMethod() { + CodeContribution contribution = contribute(CustomPropertiesPersistenceContextMethod.class); + assertThat(contribution).isNotNull(); + assertThat(contribution.runtimeHints().reflection().typeHints()).isEmpty(); + assertThat(CodeSnippet.process(contribution.statements().toCodeBlock())).isEqualTo(""" + EntityManagerFactory entityManagerFactory = EntityManagerFactoryUtils.findEntityManagerFactory(beanFactory, ""); + Properties persistenceProperties = new Properties(); + persistenceProperties.put("jpa.test", "value"); + persistenceProperties.put("jpa.test2", "value2"); + EntityManager entityManager = SharedEntityManagerCreator.createSharedEntityManager(entityManagerFactory, persistenceProperties, true); + bean.setEntityManager(entityManager); + """); + } + + @Test + void generateEntityManagerFactoryInjection() { + GenericApplicationContext context = new AnnotationConfigApplicationContext(); + context.registerBeanDefinition("test", new RootBeanDefinition(DefaultPersistenceUnitField.class)); + + EntityManagerFactory entityManagerFactory = mock(EntityManagerFactory.class); + compile(context, toFreshApplicationContext(() -> { + GenericApplicationContext ctx = new GenericApplicationContext(); + ctx.getDefaultListableBeanFactory().registerSingleton("myEmf", entityManagerFactory); + return ctx; + }, aotContext -> assertThat(aotContext.getBean("test")).hasFieldOrPropertyWithValue("emf", entityManagerFactory))); + } + + private DefaultCodeContribution contribute(Class type) { + BeanInstantiationContribution contributor = createContribution(type); + assertThat(contributor).isNotNull(); + DefaultCodeContribution contribution = new DefaultCodeContribution(new RuntimeHints()); + contributor.applyTo(contribution); + return contribution; + } + + @Nullable + private BeanInstantiationContribution createContribution(Class beanType) { + PersistenceAnnotationBeanPostProcessor bpp = new PersistenceAnnotationBeanPostProcessor(); + RootBeanDefinition beanDefinition = new RootBeanDefinition(beanType); + return bpp.contribute(beanDefinition, beanType, "test"); + } + + @SuppressWarnings("rawtypes") + private void compile(GenericApplicationContext applicationContext, Consumer initializer) { + DefaultGeneratedTypeContext generationContext = new DefaultGeneratedTypeContext("com.example", + packageName -> GeneratedType.of(ClassName.get(packageName, "Test"))); + ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator(); + generator.generateApplicationContext(applicationContext, generationContext); + SourceFiles sourceFiles = SourceFiles.none(); + for (JavaFile javaFile : generationContext.toJavaFiles()) { + sourceFiles = sourceFiles.and(SourceFile.of((javaFile::writeTo))); + } + TestCompiler.forSystem().withSources(sourceFiles).compile(compiled -> { + ApplicationContextInitializer instance = compiled.getInstance(ApplicationContextInitializer.class, "com.example.Test"); + initializer.accept(instance); + }); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private Consumer toFreshApplicationContext( + Supplier applicationContextFactory, Consumer context) { + return applicationContextInitializer -> { + T applicationContext = applicationContextFactory.get(); + applicationContextInitializer.initialize(applicationContext); + applicationContext.refresh(); + context.accept(applicationContext); + }; + } + + + static class DefaultPersistenceUnitField { + + @PersistenceUnit + public EntityManagerFactory emf; + + } + + static class DefaultPersistenceUnitMethod { + + @PersistenceUnit + public void setEmf(EntityManagerFactory emf) { + } + + } + + static class CustomUnitNamePublicPersistenceUnitMethod { + + @PersistenceUnit(unitName = "custom") + public void setEmf(EntityManagerFactory emf) { + } + + } + + static class DefaultPersistenceContextField { + + @PersistenceContext + private EntityManager entityManager; + + } + + static class CustomPropertiesPersistenceContextMethod { + + @PersistenceContext(properties = { + @PersistenceProperty(name = "jpa.test", value = "value"), + @PersistenceProperty(name = "jpa.test2", value = "value2") }) + public void setEntityManager(EntityManager entityManager) { + + } + + } + +}