From ebf6de8f5d7acc21e21138af897d511a25038d3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Deleuze?= Date: Fri, 19 Aug 2022 17:28:09 +0200 Subject: [PATCH] Infer JDK dynamic proxies for Spring beans See gh-28980 --- .../aot/ApplicationContextAotGenerator.java | 2 +- .../support/GenericApplicationContext.java | 12 +++- ...notationConfigApplicationContextTests.java | 7 ++- .../GenericApplicationContextTests.java | 60 ++++++++++++++----- 4 files changed, 60 insertions(+), 21 deletions(-) diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java index f7df128ca6..6a7e568ec2 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java @@ -51,7 +51,7 @@ public class ApplicationContextAotGenerator { public ClassName processAheadOfTime(GenericApplicationContext applicationContext, GenerationContext generationContext) { return withGeneratedClassHandler(new GeneratedClassHandler(generationContext), () -> { - applicationContext.refreshForAotProcessing(); + applicationContext.refreshForAotProcessing(generationContext.getRuntimeHints()); DefaultListableBeanFactory beanFactory = applicationContext.getDefaultListableBeanFactory(); ApplicationContextInitializationCodeGenerator codeGenerator = new ApplicationContextInitializationCodeGenerator(generationContext); diff --git a/spring-context/src/main/java/org/springframework/context/support/GenericApplicationContext.java b/spring-context/src/main/java/org/springframework/context/support/GenericApplicationContext.java index dc332d97b8..4cb5f7260f 100644 --- a/spring-context/src/main/java/org/springframework/context/support/GenericApplicationContext.java +++ b/spring-context/src/main/java/org/springframework/context/support/GenericApplicationContext.java @@ -18,10 +18,12 @@ package org.springframework.context.support; import java.io.IOException; import java.lang.reflect.Constructor; +import java.lang.reflect.Proxy; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; +import org.springframework.aot.hint.RuntimeHints; import org.springframework.beans.BeanUtils; import org.springframework.beans.BeansException; import org.springframework.beans.factory.BeanDefinitionStoreException; @@ -394,12 +396,13 @@ public class GenericApplicationContext extends AbstractApplicationContext implem * processing that optimizes the application context, typically at build time. *

In this mode, only {@link BeanDefinitionRegistryPostProcessor} and * {@link MergedBeanDefinitionPostProcessor} are invoked. + * @param runtimeHints the runtime hints * @throws BeansException if the bean factory could not be initialized * @throws IllegalStateException if already initialized and multiple refresh * attempts are not supported * @since 6.0 */ - public void refreshForAotProcessing() { + public void refreshForAotProcessing(RuntimeHints runtimeHints) { if (logger.isDebugEnabled()) { logger.debug("Preparing bean factory for AOT processing"); } @@ -410,7 +413,7 @@ public class GenericApplicationContext extends AbstractApplicationContext implem invokeBeanFactoryPostProcessors(this.beanFactory); this.beanFactory.freezeConfiguration(); PostProcessorRegistrationDelegate.invokeMergedBeanDefinitionPostProcessors(this.beanFactory); - preDetermineBeanTypes(); + preDetermineBeanTypes(runtimeHints); } /** @@ -418,7 +421,7 @@ public class GenericApplicationContext extends AbstractApplicationContext implem * @see org.springframework.beans.factory.BeanFactory#getType * @see SmartInstantiationAwareBeanPostProcessor#determineBeanType */ - private void preDetermineBeanTypes() { + private void preDetermineBeanTypes(RuntimeHints runtimeHints) { List bpps = PostProcessorRegistrationDelegate.loadBeanPostProcessors( this.beanFactory, SmartInstantiationAwareBeanPostProcessor.class); @@ -427,6 +430,9 @@ public class GenericApplicationContext extends AbstractApplicationContext implem if (beanType != null) { for (SmartInstantiationAwareBeanPostProcessor bpp : bpps) { beanType = bpp.determineBeanType(beanType, beanName); + if (Proxy.isProxyClass(beanType)) { + runtimeHints.proxies().registerJdkProxy(beanType.getInterfaces()); + } } } } diff --git a/spring-context/src/test/java/org/springframework/context/annotation/AnnotationConfigApplicationContextTests.java b/spring-context/src/test/java/org/springframework/context/annotation/AnnotationConfigApplicationContextTests.java index 57831c404d..080149d761 100644 --- a/spring-context/src/test/java/org/springframework/context/annotation/AnnotationConfigApplicationContextTests.java +++ b/spring-context/src/test/java/org/springframework/context/annotation/AnnotationConfigApplicationContextTests.java @@ -21,6 +21,7 @@ import java.util.regex.Pattern; import org.junit.jupiter.api.Test; +import org.springframework.aot.hint.RuntimeHints; import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.annotation.Autowired; @@ -426,7 +427,7 @@ class AnnotationConfigApplicationContextTests { void refreshForAotProcessingWithConfiguration() { AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); context.register(Config.class); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); assertThat(context.getBeanFactory().getBeanDefinitionNames()).contains( "annotationConfigApplicationContextTests.Config", "testBean"); } @@ -435,7 +436,7 @@ class AnnotationConfigApplicationContextTests { void refreshForAotCanInstantiateBeanWithAutowiredApplicationContext() { AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); context.register(BeanD.class); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); BeanD bean = context.getBean(BeanD.class); assertThat(bean.applicationContext).isSameAs(context); } @@ -444,7 +445,7 @@ class AnnotationConfigApplicationContextTests { void refreshForAotCanInstantiateBeanWithFieldAutowiredApplicationContext() { AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); context.register(BeanB.class); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); BeanB bean = context.getBean(BeanB.class); assertThat(bean.applicationContext).isSameAs(context); } diff --git a/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java b/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java index aa85cefe45..7af3bbc20c 100644 --- a/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java +++ b/spring-context/src/test/java/org/springframework/context/support/GenericApplicationContextTests.java @@ -16,18 +16,24 @@ package org.springframework.context.support; +import java.lang.reflect.Proxy; import java.nio.file.InvalidPathException; +import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.OS; import org.mockito.ArgumentCaptor; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; import org.springframework.beans.BeansException; import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.beans.factory.config.AbstractFactoryBean; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; +import org.springframework.beans.factory.config.SmartInstantiationAwareBeanPostProcessor; import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.GenericBeanDefinition; @@ -35,6 +41,7 @@ import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcess import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; +import org.springframework.core.DecoratingProxy; import org.springframework.core.env.ConfigurableEnvironment; import org.springframework.core.env.Environment; import org.springframework.core.io.ByteArrayResource; @@ -296,7 +303,7 @@ class GenericApplicationContextTests { void refreshForAotSetsContextActive() { GenericApplicationContext context = new GenericApplicationContext(); assertThat(context.isActive()).isFalse(); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); assertThat(context.isActive()).isTrue(); context.close(); } @@ -306,7 +313,7 @@ class GenericApplicationContextTests { ConfigurableEnvironment environment = mock(ConfigurableEnvironment.class); GenericApplicationContext context = new GenericApplicationContext(); context.setEnvironment(environment); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); assertThat(context.getBean(Environment.class)).isEqualTo(environment); context.close(); } @@ -315,7 +322,7 @@ class GenericApplicationContextTests { void refreshForAotLoadsBeanClassName() { GenericApplicationContext context = new GenericApplicationContext(); context.registerBeanDefinition("number", new RootBeanDefinition("java.lang.Integer")); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); assertThat(getBeanDefinition(context, "number").getBeanClass()).isEqualTo(Integer.class); context.close(); } @@ -328,7 +335,7 @@ class GenericApplicationContextTests { innerBeanDefinition.setBeanClassName("java.lang.Integer"); beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, innerBeanDefinition); context.registerBeanDefinition("test",beanDefinition); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); RootBeanDefinition bd = getBeanDefinition(context, "test"); GenericBeanDefinition value = (GenericBeanDefinition) bd.getConstructorArgumentValues() .getIndexedArgumentValue(0, GenericBeanDefinition.class).getValue(); @@ -345,7 +352,7 @@ class GenericApplicationContextTests { innerBeanDefinition.setBeanClassName("java.lang.Integer"); beanDefinition.getPropertyValues().add("inner", innerBeanDefinition); context.registerBeanDefinition("test",beanDefinition); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); RootBeanDefinition bd = getBeanDefinition(context, "test"); GenericBeanDefinition value = (GenericBeanDefinition) bd.getPropertyValues().get("inner"); assertThat(value.hasBeanClass()).isTrue(); @@ -358,7 +365,7 @@ class GenericApplicationContextTests { GenericApplicationContext context = new GenericApplicationContext(); BeanFactoryPostProcessor bfpp = mock(BeanFactoryPostProcessor.class); context.addBeanFactoryPostProcessor(bfpp); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); verify(bfpp).postProcessBeanFactory(context.getBeanFactory()); context.close(); } @@ -369,7 +376,7 @@ class GenericApplicationContextTests { context.registerBeanDefinition("test", new RootBeanDefinition(String.class)); context.registerBeanDefinition("number", new RootBeanDefinition("java.lang.Integer")); MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "test"), String.class, "test"); verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "number"), Integer.class, "number"); context.close(); @@ -384,7 +391,7 @@ class GenericApplicationContextTests { beanDefinition.getConstructorArgumentValues().addIndexedArgumentValue(0, innerBeanDefinition); context.registerBeanDefinition("test", beanDefinition); MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); ArgumentCaptor captor = ArgumentCaptor.forClass(String.class); verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "test"), BeanD.class, "test"); verify(bpp).postProcessMergedBeanDefinition(any(RootBeanDefinition.class), eq(Integer.class), captor.capture()); @@ -401,7 +408,7 @@ class GenericApplicationContextTests { beanDefinition.getPropertyValues().add("counter", innerBeanDefinition); context.registerBeanDefinition("test", beanDefinition); MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); ArgumentCaptor captor = ArgumentCaptor.forClass(String.class); verify(bpp).postProcessMergedBeanDefinition(getBeanDefinition(context, "test"), BeanD.class, "test"); verify(bpp).postProcessMergedBeanDefinition(any(RootBeanDefinition.class), eq(Integer.class), captor.capture()); @@ -414,7 +421,7 @@ class GenericApplicationContextTests { GenericApplicationContext context = new GenericApplicationContext(); context.registerBeanDefinition("test", new RootBeanDefinition(String.class)); MergedBeanDefinitionPostProcessor bpp = registerMockMergedBeanDefinitionPostProcessor(context); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); RootBeanDefinition mergedBeanDefinition = getBeanDefinition(context, "test"); verify(bpp).postProcessMergedBeanDefinition(mergedBeanDefinition, String.class, "test"); context.getBeanFactory().clearMetadataCache(); @@ -442,7 +449,7 @@ class GenericApplicationContextTests { AbstractBeanDefinition bd = BeanDefinitionBuilder.rootBeanDefinition(String.class) .addConstructorArgValue("value").getBeanDefinition(); context.registerBeanDefinition("test", bd); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); assertThat(context.getBeanFactory().getMergedBeanDefinition("test") .hasAttribute("mbdppCalled")).isTrue(); assertThat(context.getBean("test")).isEqualTo("42"); @@ -453,7 +460,7 @@ class GenericApplicationContextTests { void refreshForAotFailsOnAnActiveContext() { GenericApplicationContext context = new GenericApplicationContext(); context.refresh(); - assertThatIllegalStateException().isThrownBy(context::refreshForAotProcessing) + assertThatIllegalStateException().isThrownBy(() -> context.refreshForAotProcessing(new RuntimeHints())) .withMessageContaining("does not support multiple refresh attempts"); context.close(); } @@ -463,7 +470,7 @@ class GenericApplicationContextTests { GenericApplicationContext context = new GenericApplicationContext(); context.registerBeanDefinition("genericFactoryBean", new RootBeanDefinition(TestAotFactoryBean.class)); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); context.close(); } @@ -473,7 +480,32 @@ class GenericApplicationContextTests { context.registerBeanDefinition("test", BeanDefinitionBuilder.rootBeanDefinition(String.class, () -> { throw new IllegalStateException("Should not be invoked"); }).getBeanDefinition()); - context.refreshForAotProcessing(); + context.refreshForAotProcessing(new RuntimeHints()); + context.close(); + } + + @Test + void refreshForAotRegisterProxyHint() { + GenericApplicationContext context = new GenericApplicationContext(); + context.registerBeanDefinition("bpp", BeanDefinitionBuilder.rootBeanDefinition( + SmartInstantiationAwareBeanPostProcessor.class, () -> new SmartInstantiationAwareBeanPostProcessor() { + @Override + public Class determineBeanType(Class beanClass, String beanName) throws BeansException { + if (beanClass.isInterface()) { + return Proxy.newProxyInstance(GenericApplicationContextTests.class.getClassLoader(), + new Class[] { Map.class, DecoratingProxy.class }, (proxy, method, args) -> null).getClass(); + } + else { + return beanClass; + } + } + }) + .setRole(BeanDefinition.ROLE_INFRASTRUCTURE).getBeanDefinition()); + context.registerBeanDefinition("map", BeanDefinitionBuilder.rootBeanDefinition(Map.class, + HashMap::new).getBeanDefinition()); + RuntimeHints runtimeHints = new RuntimeHints(); + context.refreshForAotProcessing(runtimeHints); + assertThat(RuntimeHintsPredicates.proxies().forInterfaces(Map.class, DecoratingProxy.class)).accepts(runtimeHints); context.close(); }