Browse Source

Infer JDK dynamic proxies for Spring beans

See gh-28980
pull/28958/head
Sébastien Deleuze 2 years ago
parent
commit
ebf6de8f5d
  1. 2
      spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java
  2. 12
      spring-context/src/main/java/org/springframework/context/support/GenericApplicationContext.java
  3. 7
      spring-context/src/test/java/org/springframework/context/annotation/AnnotationConfigApplicationContextTests.java
  4. 60
      spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java

2
spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java

@ -51,7 +51,7 @@ public class ApplicationContextAotGenerator { @@ -51,7 +51,7 @@ public class ApplicationContextAotGenerator {
public ClassName processAheadOfTime(GenericApplicationContext applicationContext,
GenerationContext generationContext) {
return withGeneratedClassHandler(new GeneratedClassHandler(generationContext), () -> {
applicationContext.refreshForAotProcessing();
applicationContext.refreshForAotProcessing(generationContext.getRuntimeHints());
DefaultListableBeanFactory beanFactory = applicationContext.getDefaultListableBeanFactory();
ApplicationContextInitializationCodeGenerator codeGenerator =
new ApplicationContextInitializationCodeGenerator(generationContext);

12
spring-context/src/main/java/org/springframework/context/support/GenericApplicationContext.java

@ -18,10 +18,12 @@ package org.springframework.context.support; @@ -18,10 +18,12 @@ package org.springframework.context.support;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Proxy;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanDefinitionStoreException;
@ -394,12 +396,13 @@ public class GenericApplicationContext extends AbstractApplicationContext implem @@ -394,12 +396,13 @@ public class GenericApplicationContext extends AbstractApplicationContext implem
* processing that optimizes the application context, typically at build time.
* <p>In this mode, only {@link BeanDefinitionRegistryPostProcessor} and
* {@link MergedBeanDefinitionPostProcessor} are invoked.
* @param runtimeHints the runtime hints
* @throws BeansException if the bean factory could not be initialized
* @throws IllegalStateException if already initialized and multiple refresh
* attempts are not supported
* @since 6.0
*/
public void refreshForAotProcessing() {
public void refreshForAotProcessing(RuntimeHints runtimeHints) {
if (logger.isDebugEnabled()) {
logger.debug("Preparing bean factory for AOT processing");
}
@ -410,7 +413,7 @@ public class GenericApplicationContext extends AbstractApplicationContext implem @@ -410,7 +413,7 @@ public class GenericApplicationContext extends AbstractApplicationContext implem
invokeBeanFactoryPostProcessors(this.beanFactory);
this.beanFactory.freezeConfiguration();
PostProcessorRegistrationDelegate.invokeMergedBeanDefinitionPostProcessors(this.beanFactory);
preDetermineBeanTypes();
preDetermineBeanTypes(runtimeHints);
}
/**
@ -418,7 +421,7 @@ public class GenericApplicationContext extends AbstractApplicationContext implem @@ -418,7 +421,7 @@ public class GenericApplicationContext extends AbstractApplicationContext implem
* @see org.springframework.beans.factory.BeanFactory#getType
* @see SmartInstantiationAwareBeanPostProcessor#determineBeanType
*/
private void preDetermineBeanTypes() {
private void preDetermineBeanTypes(RuntimeHints runtimeHints) {
List<SmartInstantiationAwareBeanPostProcessor> bpps =
PostProcessorRegistrationDelegate.loadBeanPostProcessors(
this.beanFactory, SmartInstantiationAwareBeanPostProcessor.class);
@ -427,6 +430,9 @@ public class GenericApplicationContext extends AbstractApplicationContext implem @@ -427,6 +430,9 @@ public class GenericApplicationContext extends AbstractApplicationContext implem
if (beanType != null) {
for (SmartInstantiationAwareBeanPostProcessor bpp : bpps) {
beanType = bpp.determineBeanType(beanType, beanName);
if (Proxy.isProxyClass(beanType)) {
runtimeHints.proxies().registerJdkProxy(beanType.getInterfaces());
}
}
}
}

7
spring-context/src/test/java/org/springframework/context/annotation/AnnotationConfigApplicationContextTests.java

@ -21,6 +21,7 @@ import java.util.regex.Pattern; @@ -21,6 +21,7 @@ import java.util.regex.Pattern;
import org.junit.jupiter.api.Test;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.annotation.Autowired;
@ -426,7 +427,7 @@ class AnnotationConfigApplicationContextTests { @@ -426,7 +427,7 @@ class AnnotationConfigApplicationContextTests {
void refreshForAotProcessingWithConfiguration() {
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
context.register(Config.class);
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
assertThat(context.getBeanFactory().getBeanDefinitionNames()).contains(
"annotationConfigApplicationContextTests.Config", "testBean");
}
@ -435,7 +436,7 @@ class AnnotationConfigApplicationContextTests { @@ -435,7 +436,7 @@ class AnnotationConfigApplicationContextTests {
void refreshForAotCanInstantiateBeanWithAutowiredApplicationContext() {
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
context.register(BeanD.class);
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
BeanD bean = context.getBean(BeanD.class);
assertThat(bean.applicationContext).isSameAs(context);
}
@ -444,7 +445,7 @@ class AnnotationConfigApplicationContextTests { @@ -444,7 +445,7 @@ class AnnotationConfigApplicationContextTests {
void refreshForAotCanInstantiateBeanWithFieldAutowiredApplicationContext() {
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
context.register(BeanB.class);
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
BeanB bean = context.getBean(BeanB.class);
assertThat(bean.applicationContext).isSameAs(context);
}

60
spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java

@ -16,18 +16,24 @@ @@ -16,18 +16,24 @@
package org.springframework.context.support;
import java.lang.reflect.Proxy;
import java.nio.file.InvalidPathException;
import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.OS;
import org.mockito.ArgumentCaptor;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.predicate.RuntimeHintsPredicates;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.beans.factory.config.AbstractFactoryBean;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.SmartInstantiationAwareBeanPostProcessor;
import org.springframework.beans.factory.support.AbstractBeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.GenericBeanDefinition;
@ -35,6 +41,7 @@ import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcess @@ -35,6 +41,7 @@ import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcess
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.DecoratingProxy;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.Environment;
import org.springframework.core.io.ByteArrayResource;
@ -296,7 +303,7 @@ class GenericApplicationContextTests { @@ -296,7 +303,7 @@ class GenericApplicationContextTests {
void refreshForAotSetsContextActive() {
GenericApplicationContext context = new GenericApplicationContext();
assertThat(context.isActive()).isFalse();
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
assertThat(context.isActive()).isTrue();
context.close();
}
@ -306,7 +313,7 @@ class GenericApplicationContextTests { @@ -306,7 +313,7 @@ class GenericApplicationContextTests {
ConfigurableEnvironment environment = mock(ConfigurableEnvironment.class);
GenericApplicationContext context = new GenericApplicationContext();
context.setEnvironment(environment);
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
assertThat(context.getBean(Environment.class)).isEqualTo(environment);
context.close();
}
@ -315,7 +322,7 @@ class GenericApplicationContextTests { @@ -315,7 +322,7 @@ class GenericApplicationContextTests {
void refreshForAotLoadsBeanClassName() {
GenericApplicationContext context = new GenericApplicationContext();
context.registerBeanDefinition("number", new RootBeanDefinition("java.lang.Integer"));
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
assertThat(getBeanDefinition(context, "number").getBeanClass()).isEqualTo(Integer.class);
context.close();
}
@ -328,7 +335,7 @@ class GenericApplicationContextTests { @@ -328,7 +335,7 @@ class GenericApplicationContextTests {
innerBeanDefinition.setBeanClassName("java.lang.Integer");
beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, innerBeanDefinition);
context.registerBeanDefinition("test",beanDefinition);
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
RootBeanDefinition bd = getBeanDefinition(context, "test");
GenericBeanDefinition value = (GenericBeanDefinition) bd.getConstructorArgumentValues()
.getIndexedArgumentValue(0, GenericBeanDefinition.class).getValue();
@ -345,7 +352,7 @@ class GenericApplicationContextTests { @@ -345,7 +352,7 @@ class GenericApplicationContextTests {
innerBeanDefinition.setBeanClassName("java.lang.Integer");
beanDefinition.getPropertyValues().add("inner", innerBeanDefinition);
context.registerBeanDefinition("test",beanDefinition);
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
RootBeanDefinition bd = getBeanDefinition(context, "test");
GenericBeanDefinition value = (GenericBeanDefinition) bd.getPropertyValues().get("inner");
assertThat(value.hasBeanClass()).isTrue();
@ -358,7 +365,7 @@ class GenericApplicationContextTests { @@ -358,7 +365,7 @@ class GenericApplicationContextTests {
GenericApplicationContext context = new GenericApplicationContext();
BeanFactoryPostProcessor bfpp = mock(BeanFactoryPostProcessor.class);
context.addBeanFactoryPostProcessor(bfpp);
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
verify(bfpp).postProcessBeanFactory(context.getBeanFactory());
context.close();
}
@ -369,7 +376,7 @@ class GenericApplicationContextTests { @@ -369,7 +376,7 @@ class GenericApplicationContextTests {
context.registerBeanDefinition("test", new RootBeanDefinition(String.class));
context.registerBeanDefinition("number", new RootBeanDefinition("java.lang.Integer"));
MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context);
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "test"), String.class, "test");
verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "number"), Integer.class, "number");
context.close();
@ -384,7 +391,7 @@ class GenericApplicationContextTests { @@ -384,7 +391,7 @@ class GenericApplicationContextTests {
beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, innerBeanDefinition);
context.registerBeanDefinition("test", beanDefinition);
MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context);
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "test"), BeanD.class, "test");
verify(bpp).postProcessMergedBeanDefinition(any(RootBeanDefinition.class), eq(Integer.class), captor.capture());
@ -401,7 +408,7 @@ class GenericApplicationContextTests { @@ -401,7 +408,7 @@ class GenericApplicationContextTests {
beanDefinition.getPropertyValues().add("counter", innerBeanDefinition);
context.registerBeanDefinition("test", beanDefinition);
MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context);
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "test"), BeanD.class, "test");
verify(bpp).postProcessMergedBeanDefinition(any(RootBeanDefinition.class), eq(Integer.class), captor.capture());
@ -414,7 +421,7 @@ class GenericApplicationContextTests { @@ -414,7 +421,7 @@ class GenericApplicationContextTests {
GenericApplicationContext context = new GenericApplicationContext();
context.registerBeanDefinition("test", new RootBeanDefinition(String.class));
MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context);
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
RootBeanDefinition mergedBeanDefinition = getBeanDefinition(context, "test");
verify(bpp).postProcessMergedBeanDefinition(mergedBeanDefinition, String.class, "test");
context.getBeanFactory().clearMetadataCache();
@ -442,7 +449,7 @@ class GenericApplicationContextTests { @@ -442,7 +449,7 @@ class GenericApplicationContextTests {
AbstractBeanDefinition bd = BeanDefinitionBuilder.rootBeanDefinition(String.class)
.addConstructorArgValue("value").getBeanDefinition();
context.registerBeanDefinition("test", bd);
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
assertThat(context.getBeanFactory().getMergedBeanDefinition("test")
.hasAttribute("mbdppCalled")).isTrue();
assertThat(context.getBean("test")).isEqualTo("42");
@ -453,7 +460,7 @@ class GenericApplicationContextTests { @@ -453,7 +460,7 @@ class GenericApplicationContextTests {
void refreshForAotFailsOnAnActiveContext() {
GenericApplicationContext context = new GenericApplicationContext();
context.refresh();
assertThatIllegalStateException().isThrownBy(context::refreshForAotProcessing)
assertThatIllegalStateException().isThrownBy(() -> context.refreshForAotProcessing(new RuntimeHints()))
.withMessageContaining("does not support multiple refresh attempts");
context.close();
}
@ -463,7 +470,7 @@ class GenericApplicationContextTests { @@ -463,7 +470,7 @@ class GenericApplicationContextTests {
GenericApplicationContext context = new GenericApplicationContext();
context.registerBeanDefinition("genericFactoryBean",
new RootBeanDefinition(TestAotFactoryBean.class));
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
context.close();
}
@ -473,7 +480,32 @@ class GenericApplicationContextTests { @@ -473,7 +480,32 @@ class GenericApplicationContextTests {
context.registerBeanDefinition("test", BeanDefinitionBuilder.rootBeanDefinition(String.class, () -> {
throw new IllegalStateException("Should not be invoked");
}).getBeanDefinition());
context.refreshForAotProcessing();
context.refreshForAotProcessing(new RuntimeHints());
context.close();
}
@Test
void refreshForAotRegisterProxyHint() {
GenericApplicationContext context = new GenericApplicationContext();
context.registerBeanDefinition("bpp", BeanDefinitionBuilder.rootBeanDefinition(
SmartInstantiationAwareBeanPostProcessor.class, () -> new SmartInstantiationAwareBeanPostProcessor() {
@Override
public Class<?> determineBeanType(Class<?> beanClass, String beanName) throws BeansException {
if (beanClass.isInterface()) {
return Proxy.newProxyInstance(GenericApplicationContextTests.class.getClassLoader(),
new Class[] { Map.class, DecoratingProxy.class }, (proxy, method, args) -> null).getClass();
}
else {
return beanClass;
}
}
})
.setRole(BeanDefinition.ROLE_INFRASTRUCTURE).getBeanDefinition());
context.registerBeanDefinition("map", BeanDefinitionBuilder.rootBeanDefinition(Map.class,
HashMap::new).getBeanDefinition());
RuntimeHints runtimeHints = new RuntimeHints();
context.refreshForAotProcessing(runtimeHints);
assertThat(RuntimeHintsPredicates.proxies().forInterfaces(Map.class, DecoratingProxy.class)).accepts(runtimeHints);
context.close();
}

Loading…
Cancel
Save