Browse Source

AOT contribution for @PersistenceContext and @PersistenceUnit

Closes gh-28364
pull/28395/head
Stephane Nicoll 3 years ago
parent
commit
26054fd3d4
  1. 65
      spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFieldGenerator.java
  2. 27
      spring-beans/src/main/java/org/springframework/beans/factory/generator/InjectionGenerator.java
  3. 78
      spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanFieldGeneratorTests.java
  4. 10
      spring-core/src/main/java/org/springframework/javapoet/support/MultiStatement.java
  5. 13
      spring-core/src/test/java/org/springframework/javapoet/support/MultiStatementTests.java
  6. 1
      spring-orm/spring-orm.gradle
  7. 94
      spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java
  8. 225
      spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorTests.java

65
spring-beans/src/main/java/org/springframework/beans/factory/generator/BeanFieldGenerator.java

@ -0,0 +1,65 @@ @@ -0,0 +1,65 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.beans.factory.generator;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import org.springframework.aot.generator.ProtectedAccess.Options;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.support.MultiStatement;
import org.springframework.util.ReflectionUtils;
/**
* Support for generating {@link Field} access.
*
* @author Stephane Nicoll
* @since 6.0
*/
public class BeanFieldGenerator {
/**
* The {@link Options} to use to access a field.
*/
public static final Options FIELD_OPTIONS = Options.defaults()
.useReflection(member -> Modifier.isPrivate(member.getModifiers())).build();
/**
* Generate the necessary code to set the specified field. Use reflection
* using {@link ReflectionUtils} if necessary.
* @param field the field to set
* @param value a code representation of the field value
* @return the code to set the specified field
*/
public MultiStatement generateSetValue(String target, Field field, CodeBlock value) {
MultiStatement statement = new MultiStatement();
boolean useReflection = Modifier.isPrivate(field.getModifiers());
if (useReflection) {
String fieldName = String.format("%sField", field.getName());
statement.addStatement("$T $L = $T.findField($T.class, $S)", Field.class, fieldName, ReflectionUtils.class,
field.getDeclaringClass(), field.getName());
statement.addStatement("$T.makeAccessible($L)", ReflectionUtils.class, fieldName);
statement.addStatement("$T.setField($L, $L, $L)", ReflectionUtils.class, fieldName, target, value);
}
else {
statement.addStatement("$L.$L = $L", target, field.getName(), value);
}
return statement;
}
}

27
spring-beans/src/main/java/org/springframework/beans/factory/generator/InjectionGenerator.java

@ -34,7 +34,6 @@ import org.springframework.beans.factory.generator.config.BeanDefinitionRegistra @@ -34,7 +34,6 @@ import org.springframework.beans.factory.generator.config.BeanDefinitionRegistra
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.CodeBlock.Builder;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
/**
* Generate the necessary code to {@link #generateInstantiation(Executable)
@ -53,14 +52,13 @@ import org.springframework.util.ReflectionUtils; @@ -53,14 +52,13 @@ import org.springframework.util.ReflectionUtils;
*/
public class InjectionGenerator {
private static final Options FIELD_INJECTION_OPTIONS = Options.defaults()
.useReflection(member -> Modifier.isPrivate(member.getModifiers())).build();
private static final Options METHOD_INJECTION_OPTIONS = Options.defaults()
.useReflection(member -> false).build();
private final BeanParameterGenerator parameterGenerator = new BeanParameterGenerator();
private final BeanFieldGenerator fieldGenerator = new BeanFieldGenerator();
/**
* Generate the necessary code to instantiate an object using the specified
@ -110,7 +108,7 @@ public class InjectionGenerator { @@ -110,7 +108,7 @@ public class InjectionGenerator {
return METHOD_INJECTION_OPTIONS;
}
if (member instanceof Field) {
return FIELD_INJECTION_OPTIONS;
return BeanFieldGenerator.FIELD_OPTIONS;
}
throw new IllegalArgumentException("Could not handle member " + member);
}
@ -230,24 +228,13 @@ public class InjectionGenerator { @@ -230,24 +228,13 @@ public class InjectionGenerator {
code.add("instanceContext.field($S", injectionPoint.getName());
code.add(")\n").indent().indent();
if (required) {
code.add(".invoke(beanFactory, (attributes) ->");
}
else {
code.add(".resolve(beanFactory, false).ifResolved((attributes) ->");
}
boolean hasAssignment = Modifier.isPrivate(injectionPoint.getModifiers());
if (hasAssignment) {
code.beginControlFlow("");
String fieldName = String.format("%sField", injectionPoint.getName());
code.addStatement("$T $L = $T.findField($T.class, $S)", Field.class, fieldName, ReflectionUtils.class,
injectionPoint.getDeclaringClass(), injectionPoint.getName());
code.addStatement("$T.makeAccessible($L)", ReflectionUtils.class, fieldName);
code.addStatement("$T.setField($L, bean, attributes.get(0))", ReflectionUtils.class, fieldName);
code.unindent().add("}");
code.add(".invoke(beanFactory, ");
}
else {
code.add(" bean.$L = attributes.get(0)", injectionPoint.getName());
code.add(".resolve(beanFactory, false).ifResolved(");
}
code.add(this.fieldGenerator.generateSetValue("bean", injectionPoint,
CodeBlock.of("attributes.get(0)")).toLambdaBody("(attributes) ->"));
code.add(")").unindent().unindent();
return code.build();
}

78
spring-beans/src/test/java/org/springframework/beans/factory/generator/BeanFieldGeneratorTests.java

@ -0,0 +1,78 @@ @@ -0,0 +1,78 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.beans.factory.generator;
import java.lang.reflect.Field;
import org.junit.jupiter.api.Test;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.support.CodeSnippet;
import org.springframework.javapoet.support.MultiStatement;
import org.springframework.util.ReflectionUtils;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@link BeanFieldGenerator}.
*
* @author Stephane Nicoll
*/
class BeanFieldGeneratorTests {
private final BeanFieldGenerator generator = new BeanFieldGenerator();
@Test
void generateSetFieldWithPublicField() {
MultiStatement statement = this.generator.generateSetValue("bean",
field(SampleBean.class, "one"), CodeBlock.of("$S", "test"));
assertThat(CodeSnippet.process(statement.toCodeBlock())).isEqualTo("""
bean.one = "test";
""");
}
@Test
void generateSetFieldWithPrivateField() {
MultiStatement statement = this.generator.generateSetValue("example",
field(SampleBean.class, "two"), CodeBlock.of("42"));
CodeSnippet code = CodeSnippet.of(statement.toCodeBlock());
assertThat(code.getSnippet()).isEqualTo("""
Field twoField = ReflectionUtils.findField(BeanFieldGeneratorTests.SampleBean.class, "two");
ReflectionUtils.makeAccessible(twoField);
ReflectionUtils.setField(twoField, example, 42);
""");
assertThat(code.hasImport(ReflectionUtils.class)).isTrue();
assertThat(code.hasImport(BeanFieldGeneratorTests.class)).isTrue();
}
private Field field(Class<?> type, String name) {
Field field = ReflectionUtils.findField(type, name);
assertThat(field).isNotNull();
return field;
}
public static class SampleBean {
public String one;
private int two;
}
}

10
spring-core/src/main/java/org/springframework/javapoet/support/MultiStatement.java

@ -44,6 +44,16 @@ public final class MultiStatement { @@ -44,6 +44,16 @@ public final class MultiStatement {
return this.statements.isEmpty();
}
/**
* Add the statements defined in the specified multi statement to this instance.
* @param multiStatement the statements to add
* @return {@code this}, to facilitate method chaining
*/
public MultiStatement add(MultiStatement multiStatement) {
this.statements.addAll(multiStatement.statements);
return this;
}
/**
* Add the specified {@link CodeBlock codeblock} rendered as-is.
* @param codeBlock the code block to add

13
spring-core/src/test/java/org/springframework/javapoet/support/MultiStatementTests.java

@ -150,4 +150,17 @@ class MultiStatementTests { @@ -150,4 +150,17 @@ class MultiStatementTests {
"}");
}
@Test
void addWithAnotherMultiStatement() {
MultiStatement statements = new MultiStatement();
statements.addStatement(CodeBlock.of("test.invoke()"));
MultiStatement another = new MultiStatement();
another.addStatement(CodeBlock.of("test.another()"));
statements.add(another);
assertThat(statements.toCodeBlock().toString()).isEqualTo("""
test.invoke();
test.another();
""");
}
}

1
spring-orm/spring-orm.gradle

@ -11,6 +11,7 @@ dependencies { @@ -11,6 +11,7 @@ dependencies {
optional("org.eclipse.persistence:org.eclipse.persistence.jpa")
optional("org.hibernate:hibernate-core-jakarta")
optional("jakarta.servlet:jakarta.servlet-api")
testImplementation(project(":spring-core-test"))
testImplementation(testFixtures(project(":spring-beans")))
testImplementation(testFixtures(project(":spring-context")))
testImplementation(testFixtures(project(":spring-core")))

94
spring-orm/src/main/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessor.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,11 +19,13 @@ package org.springframework.orm.jpa.support; @@ -19,11 +19,13 @@ package org.springframework.orm.jpa.support;
import java.beans.PropertyDescriptor;
import java.io.Serializable;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Field;
import java.lang.reflect.Member;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Properties;
@ -37,6 +39,8 @@ import jakarta.persistence.PersistenceProperty; @@ -37,6 +39,8 @@ import jakarta.persistence.PersistenceProperty;
import jakarta.persistence.PersistenceUnit;
import jakarta.persistence.SynchronizationType;
import org.springframework.aot.generator.CodeContribution;
import org.springframework.aot.generator.ProtectedAccess.Options;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.PropertyValues;
import org.springframework.beans.factory.BeanCreationException;
@ -45,17 +49,23 @@ import org.springframework.beans.factory.BeanFactoryAware; @@ -45,17 +49,23 @@ import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.annotation.InjectionMetadata;
import org.springframework.beans.factory.annotation.InjectionMetadata.InjectedElement;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.DestructionAwareBeanPostProcessor;
import org.springframework.beans.factory.config.InstantiationAwareBeanPostProcessor;
import org.springframework.beans.factory.config.NamedBeanHolder;
import org.springframework.beans.factory.generator.AotContributingBeanPostProcessor;
import org.springframework.beans.factory.generator.BeanFieldGenerator;
import org.springframework.beans.factory.generator.BeanInstantiationContribution;
import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcessor;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.BridgeMethodResolver;
import org.springframework.core.Ordered;
import org.springframework.core.PriorityOrdered;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.support.MultiStatement;
import org.springframework.jndi.JndiLocatorDelegate;
import org.springframework.jndi.JndiTemplate;
import org.springframework.lang.Nullable;
@ -66,6 +76,7 @@ import org.springframework.orm.jpa.ExtendedEntityManagerCreator; @@ -66,6 +76,7 @@ import org.springframework.orm.jpa.ExtendedEntityManagerCreator;
import org.springframework.orm.jpa.SharedEntityManagerCreator;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
@ -164,6 +175,7 @@ import org.springframework.util.StringUtils; @@ -164,6 +175,7 @@ import org.springframework.util.StringUtils;
*
* @author Rod Johnson
* @author Juergen Hoeller
* @author Stephane Nicoll
* @since 2.0
* @see jakarta.persistence.PersistenceUnit
* @see jakarta.persistence.PersistenceContext
@ -171,7 +183,8 @@ import org.springframework.util.StringUtils; @@ -171,7 +183,8 @@ import org.springframework.util.StringUtils;
@SuppressWarnings("serial")
public class PersistenceAnnotationBeanPostProcessor
implements InstantiationAwareBeanPostProcessor, DestructionAwareBeanPostProcessor,
MergedBeanDefinitionPostProcessor, PriorityOrdered, BeanFactoryAware, Serializable {
MergedBeanDefinitionPostProcessor, AotContributingBeanPostProcessor,
PriorityOrdered, BeanFactoryAware, Serializable {
@Nullable
private Object jndiEnvironment;
@ -332,8 +345,23 @@ public class PersistenceAnnotationBeanPostProcessor @@ -332,8 +345,23 @@ public class PersistenceAnnotationBeanPostProcessor
@Override
public void postProcessMergedBeanDefinition(RootBeanDefinition beanDefinition, Class<?> beanType, String beanName) {
findInjectionMetadata(beanDefinition, beanType, beanName);
}
@Override
public BeanInstantiationContribution contribute(RootBeanDefinition beanDefinition, Class<?> beanType, String beanName) {
InjectionMetadata metadata = findInjectionMetadata(beanDefinition, beanType, beanName);
Collection<InjectedElement> injectedElements = metadata.getInjectedElements();
if (!CollectionUtils.isEmpty(injectedElements)) {
return new PersistenceAnnotationBeanInstantiationContribution(injectedElements);
}
return null;
}
private InjectionMetadata findInjectionMetadata(RootBeanDefinition beanDefinition, Class<?> beanType, String beanName) {
InjectionMetadata metadata = findPersistenceMetadata(beanName, beanType, null);
metadata.checkConfigMembers(beanDefinition);
return metadata;
}
@Override
@ -725,4 +753,66 @@ public class PersistenceAnnotationBeanPostProcessor @@ -725,4 +753,66 @@ public class PersistenceAnnotationBeanPostProcessor
}
}
private static final class PersistenceAnnotationBeanInstantiationContribution implements BeanInstantiationContribution {
private static final BeanFieldGenerator fieldGenerator = new BeanFieldGenerator();
private final Collection<PersistenceElement> injectedElements;
private PersistenceAnnotationBeanInstantiationContribution(Collection<InjectedElement> injectedElements) {
this.injectedElements = injectedElements.stream()
.filter(obj -> obj instanceof PersistenceElement)
.map(PersistenceElement.class::cast).toList();
}
@Override
public void applyTo(CodeContribution contribution) {
this.injectedElements.forEach(element -> {
Member member = element.getMember();
analyzeMember(contribution, member);
injectElement(contribution, element);
});
}
private void analyzeMember(CodeContribution contribution, Member member) {
if (member instanceof Method) {
contribution.protectedAccess().analyze(member, Options.defaults().build());
}
else if (member instanceof Field field) {
contribution.protectedAccess().analyze(member, BeanFieldGenerator.FIELD_OPTIONS);
if (Modifier.isPrivate(field.getModifiers())) {
contribution.runtimeHints().reflection().registerField(field);
}
}
}
private void injectElement(CodeContribution contribution, PersistenceElement element) {
MultiStatement statements = contribution.statements();
statements.addStatement("$T entityManagerFactory = $T.findEntityManagerFactory(beanFactory, $S)",
EntityManagerFactory.class, EntityManagerFactoryUtils.class, element.unitName);
boolean requireEntityManager = (element.type != null);
if (requireEntityManager) {
Properties persistenceProperties = element.properties;
boolean hasPersistenceProperties = persistenceProperties != null && !persistenceProperties.isEmpty();
if (hasPersistenceProperties) {
statements.addStatement("$T persistenceProperties = new Properties()", Properties.class);
persistenceProperties.stringPropertyNames().stream().sorted(String::compareTo).forEach(propertyName ->
statements.addStatement("persistenceProperties.put($S, $S)",
propertyName, persistenceProperties.getProperty(propertyName)));
}
statements.addStatement("$T entityManager = $T.createSharedEntityManager(entityManagerFactory, $L, $L)",
EntityManager.class, SharedEntityManagerCreator.class, (hasPersistenceProperties) ? "persistenceProperties" : null, element.synchronizedWithTransaction);
}
Member member = element.getMember();
CodeBlock value = (requireEntityManager) ? CodeBlock.of("entityManager") : CodeBlock.of("entityManagerFactory");
if (member instanceof Field field) {
statements.add(fieldGenerator.generateSetValue("bean", field, value));
}
else {
statements.addStatement("bean.$L($L)", member.getName(), value);
}
}
}
}

225
spring-orm/src/test/java/org/springframework/orm/jpa/support/PersistenceAnnotationBeanPostProcessorTests.java

@ -0,0 +1,225 @@ @@ -0,0 +1,225 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.orm.jpa.support;
import java.util.function.Consumer;
import java.util.function.Supplier;
import jakarta.persistence.EntityManager;
import jakarta.persistence.EntityManagerFactory;
import jakarta.persistence.PersistenceContext;
import jakarta.persistence.PersistenceProperty;
import jakarta.persistence.PersistenceUnit;
import org.junit.jupiter.api.Test;
import org.springframework.aot.generator.CodeContribution;
import org.springframework.aot.generator.DefaultCodeContribution;
import org.springframework.aot.generator.DefaultGeneratedTypeContext;
import org.springframework.aot.generator.GeneratedType;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeReference;
import org.springframework.aot.test.generator.compile.TestCompiler;
import org.springframework.aot.test.generator.file.SourceFile;
import org.springframework.aot.test.generator.file.SourceFiles;
import org.springframework.beans.factory.generator.BeanInstantiationContribution;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.generator.ApplicationContextAotGenerator;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.JavaFile;
import org.springframework.javapoet.support.CodeSnippet;
import org.springframework.lang.Nullable;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link PersistenceAnnotationBeanPostProcessor}.
*
* @author Stephane Nicoll
*/
class PersistenceAnnotationBeanPostProcessorTests {
@Test
void contributeForPersistenceUnitOnPublicField() {
CodeContribution contribution = contribute(DefaultPersistenceUnitField.class);
assertThat(contribution).isNotNull();
assertThat(contribution.runtimeHints().reflection().typeHints()).isEmpty();
assertThat(CodeSnippet.process(contribution.statements().toCodeBlock())).isEqualTo("""
EntityManagerFactory entityManagerFactory = EntityManagerFactoryUtils.findEntityManagerFactory(beanFactory, "");
bean.emf = entityManagerFactory;
""");
}
@Test
void contributeForPersistenceUnitOnPublicSetter() {
CodeContribution contribution = contribute(DefaultPersistenceUnitMethod.class);
assertThat(contribution).isNotNull();
assertThat(contribution.runtimeHints().reflection().typeHints()).isEmpty();
assertThat(CodeSnippet.process(contribution.statements().toCodeBlock())).isEqualTo("""
EntityManagerFactory entityManagerFactory = EntityManagerFactoryUtils.findEntityManagerFactory(beanFactory, "");
bean.setEmf(entityManagerFactory);
""");
}
@Test
void contributeForPersistenceUnitWithCustomUnitOnPublicSetter() {
CodeContribution contribution = contribute(CustomUnitNamePublicPersistenceUnitMethod.class);
assertThat(contribution).isNotNull();
assertThat(contribution.runtimeHints().reflection().typeHints()).isEmpty();
assertThat(CodeSnippet.process(contribution.statements().toCodeBlock())).isEqualTo("""
EntityManagerFactory entityManagerFactory = EntityManagerFactoryUtils.findEntityManagerFactory(beanFactory, "custom");
bean.setEmf(entityManagerFactory);
""");
}
@Test
void contributeForPersistenceContextOnPrivateField() {
CodeContribution contribution = contribute(DefaultPersistenceContextField.class);
assertThat(contribution).isNotNull();
assertThat(contribution.runtimeHints().reflection().typeHints()).singleElement().satisfies(typeHint -> {
assertThat(typeHint.getType()).isEqualTo(TypeReference.of(DefaultPersistenceContextField.class));
assertThat(typeHint.fields()).singleElement().satisfies(fieldHint -> {
assertThat(fieldHint.getName()).isEqualTo("entityManager");
assertThat(fieldHint.isAllowWrite()).isTrue();
assertThat(fieldHint.isAllowUnsafeAccess()).isFalse();
});
});
assertThat(CodeSnippet.process(contribution.statements().toCodeBlock())).isEqualTo("""
EntityManagerFactory entityManagerFactory = EntityManagerFactoryUtils.findEntityManagerFactory(beanFactory, "");
EntityManager entityManager = SharedEntityManagerCreator.createSharedEntityManager(entityManagerFactory, null, true);
Field entityManagerField = ReflectionUtils.findField(PersistenceAnnotationBeanPostProcessorTests.DefaultPersistenceContextField.class, "entityManager");
ReflectionUtils.makeAccessible(entityManagerField);
ReflectionUtils.setField(entityManagerField, bean, entityManager);
""");
}
@Test
void contributeForPersistenceContextWithCustomPropertiesOnMethod() {
CodeContribution contribution = contribute(CustomPropertiesPersistenceContextMethod.class);
assertThat(contribution).isNotNull();
assertThat(contribution.runtimeHints().reflection().typeHints()).isEmpty();
assertThat(CodeSnippet.process(contribution.statements().toCodeBlock())).isEqualTo("""
EntityManagerFactory entityManagerFactory = EntityManagerFactoryUtils.findEntityManagerFactory(beanFactory, "");
Properties persistenceProperties = new Properties();
persistenceProperties.put("jpa.test", "value");
persistenceProperties.put("jpa.test2", "value2");
EntityManager entityManager = SharedEntityManagerCreator.createSharedEntityManager(entityManagerFactory, persistenceProperties, true);
bean.setEntityManager(entityManager);
""");
}
@Test
void generateEntityManagerFactoryInjection() {
GenericApplicationContext context = new AnnotationConfigApplicationContext();
context.registerBeanDefinition("test", new RootBeanDefinition(DefaultPersistenceUnitField.class));
EntityManagerFactory entityManagerFactory = mock(EntityManagerFactory.class);
compile(context, toFreshApplicationContext(() -> {
GenericApplicationContext ctx = new GenericApplicationContext();
ctx.getDefaultListableBeanFactory().registerSingleton("myEmf", entityManagerFactory);
return ctx;
}, aotContext -> assertThat(aotContext.getBean("test")).hasFieldOrPropertyWithValue("emf", entityManagerFactory)));
}
private DefaultCodeContribution contribute(Class<?> type) {
BeanInstantiationContribution contributor = createContribution(type);
assertThat(contributor).isNotNull();
DefaultCodeContribution contribution = new DefaultCodeContribution(new RuntimeHints());
contributor.applyTo(contribution);
return contribution;
}
@Nullable
private BeanInstantiationContribution createContribution(Class<?> beanType) {
PersistenceAnnotationBeanPostProcessor bpp = new PersistenceAnnotationBeanPostProcessor();
RootBeanDefinition beanDefinition = new RootBeanDefinition(beanType);
return bpp.contribute(beanDefinition, beanType, "test");
}
@SuppressWarnings("rawtypes")
private void compile(GenericApplicationContext applicationContext, Consumer<ApplicationContextInitializer> initializer) {
DefaultGeneratedTypeContext generationContext = new DefaultGeneratedTypeContext("com.example",
packageName -> GeneratedType.of(ClassName.get(packageName, "Test")));
ApplicationContextAotGenerator generator = new ApplicationContextAotGenerator();
generator.generateApplicationContext(applicationContext, generationContext);
SourceFiles sourceFiles = SourceFiles.none();
for (JavaFile javaFile : generationContext.toJavaFiles()) {
sourceFiles = sourceFiles.and(SourceFile.of((javaFile::writeTo)));
}
TestCompiler.forSystem().withSources(sourceFiles).compile(compiled -> {
ApplicationContextInitializer instance = compiled.getInstance(ApplicationContextInitializer.class, "com.example.Test");
initializer.accept(instance);
});
}
@SuppressWarnings({ "rawtypes", "unchecked" })
private <T extends GenericApplicationContext> Consumer<ApplicationContextInitializer> toFreshApplicationContext(
Supplier<T> applicationContextFactory, Consumer<T> context) {
return applicationContextInitializer -> {
T applicationContext = applicationContextFactory.get();
applicationContextInitializer.initialize(applicationContext);
applicationContext.refresh();
context.accept(applicationContext);
};
}
static class DefaultPersistenceUnitField {
@PersistenceUnit
public EntityManagerFactory emf;
}
static class DefaultPersistenceUnitMethod {
@PersistenceUnit
public void setEmf(EntityManagerFactory emf) {
}
}
static class CustomUnitNamePublicPersistenceUnitMethod {
@PersistenceUnit(unitName = "custom")
public void setEmf(EntityManagerFactory emf) {
}
}
static class DefaultPersistenceContextField {
@PersistenceContext
private EntityManager entityManager;
}
static class CustomPropertiesPersistenceContextMethod {
@PersistenceContext(properties = {
@PersistenceProperty(name = "jpa.test", value = "value"),
@PersistenceProperty(name = "jpa.test2", value = "value2") })
public void setEntityManager(EntityManager entityManager) {
}
}
}
Loading…
Cancel
Save