From d870b382da1c914d1095b88a9fb92835678f53d7 Mon Sep 17 00:00:00 2001 From: Juergen Hoeller Date: Tue, 1 Jul 2014 23:50:17 +0200 Subject: [PATCH] Defensively check for pre-resolved FactoryBean.getObject() results in circular reference scenarios Issue: SPR-11937 --- .../support/FactoryBeanRegistrySupport.java | 51 ++++--- .../factory/FactoryBeanTests-circular.xml | 25 ++++ .../beans/factory/FactoryBeanTests.java | 140 +++++++++++++++++- 3 files changed, 192 insertions(+), 24 deletions(-) create mode 100644 spring-beans/src/test/java/org/springframework/beans/factory/FactoryBeanTests-circular.xml diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/FactoryBeanRegistrySupport.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/FactoryBeanRegistrySupport.java index afda1d4c60..53dc17657e 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/FactoryBeanRegistrySupport.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/FactoryBeanRegistrySupport.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -90,7 +90,7 @@ public abstract class FactoryBeanRegistrySupport extends DefaultSingletonBeanReg * Obtain an object to expose from the given FactoryBean. * @param factory the FactoryBean instance * @param beanName the name of the bean - * @param shouldPostProcess whether the bean is subject for post-processing + * @param shouldPostProcess whether the bean is subject to post-processing * @return the object obtained from the FactoryBean * @throws BeanCreationException if FactoryBean object creation failed * @see org.springframework.beans.factory.FactoryBean#getObject() @@ -100,14 +100,40 @@ public abstract class FactoryBeanRegistrySupport extends DefaultSingletonBeanReg synchronized (getSingletonMutex()) { Object object = this.factoryBeanObjectCache.get(beanName); if (object == null) { - object = doGetObjectFromFactoryBean(factory, beanName, shouldPostProcess); - this.factoryBeanObjectCache.put(beanName, (object != null ? object : NULL_OBJECT)); + object = doGetObjectFromFactoryBean(factory, beanName); + // Only post-process and store if not put there already during getObject() call above + // (e.g. because of circular reference processing triggered by custom getBean calls) + Object alreadyThere = this.factoryBeanObjectCache.get(beanName); + if (alreadyThere != null) { + object = alreadyThere; + } + else { + if (object != null && shouldPostProcess) { + try { + object = postProcessObjectFromFactoryBean(object, beanName); + } + catch (Throwable ex) { + throw new BeanCreationException(beanName, + "Post-processing of FactoryBean's singleton object failed", ex); + } + } + this.factoryBeanObjectCache.put(beanName, (object != null ? object : NULL_OBJECT)); + } } return (object != NULL_OBJECT ? object : null); } } else { - return doGetObjectFromFactoryBean(factory, beanName, shouldPostProcess); + Object object = doGetObjectFromFactoryBean(factory, beanName); + if (object != null && shouldPostProcess) { + try { + object = postProcessObjectFromFactoryBean(object, beanName); + } + catch (Throwable ex) { + throw new BeanCreationException(beanName, "Post-processing of FactoryBean's object failed", ex); + } + } + return object; } } @@ -115,13 +141,11 @@ public abstract class FactoryBeanRegistrySupport extends DefaultSingletonBeanReg * Obtain an object to expose from the given FactoryBean. * @param factory the FactoryBean instance * @param beanName the name of the bean - * @param shouldPostProcess whether the bean is subject for post-processing * @return the object obtained from the FactoryBean * @throws BeanCreationException if FactoryBean object creation failed * @see org.springframework.beans.factory.FactoryBean#getObject() */ - private Object doGetObjectFromFactoryBean( - final FactoryBean factory, final String beanName, final boolean shouldPostProcess) + private Object doGetObjectFromFactoryBean(final FactoryBean factory, final String beanName) throws BeanCreationException { Object object; @@ -151,23 +175,12 @@ public abstract class FactoryBeanRegistrySupport extends DefaultSingletonBeanReg throw new BeanCreationException(beanName, "FactoryBean threw exception on object creation", ex); } - // Do not accept a null value for a FactoryBean that's not fully // initialized yet: Many FactoryBeans just return null then. if (object == null && isSingletonCurrentlyInCreation(beanName)) { throw new BeanCurrentlyInCreationException( beanName, "FactoryBean which is currently in creation returned null from getObject"); } - - if (object != null && shouldPostProcess) { - try { - object = postProcessObjectFromFactoryBean(object, beanName); - } - catch (Throwable ex) { - throw new BeanCreationException(beanName, "Post-processing of the FactoryBean's object failed", ex); - } - } - return object; } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/FactoryBeanTests-circular.xml b/spring-beans/src/test/java/org/springframework/beans/factory/FactoryBeanTests-circular.xml new file mode 100644 index 0000000000..1c328aa8d4 --- /dev/null +++ b/spring-beans/src/test/java/org/springframework/beans/factory/FactoryBeanTests-circular.xml @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/FactoryBeanTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/FactoryBeanTests.java index 5d7808f426..776cb3763e 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/FactoryBeanTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/FactoryBeanTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,19 +16,24 @@ package org.springframework.beans.factory; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; -import static org.springframework.tests.TestResourceUtils.qualifiedResource; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; + +import org.springframework.beans.BeansException; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; +import org.springframework.beans.factory.config.BeanPostProcessor; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; import org.springframework.core.io.Resource; import org.springframework.stereotype.Component; import org.springframework.util.Assert; +import static org.junit.Assert.*; +import static org.springframework.tests.TestResourceUtils.*; + /** * @author Rob Harrop * @author Juergen Hoeller @@ -40,6 +45,7 @@ public final class FactoryBeanTests { private static final Resource RETURNS_NULL_CONTEXT = qualifiedResource(CLASS, "returnsNull.xml"); private static final Resource WITH_AUTOWIRING_CONTEXT = qualifiedResource(CLASS, "withAutowiring.xml"); private static final Resource ABSTRACT_CONTEXT = qualifiedResource(CLASS, "abstract.xml"); + private static final Resource CIRCULAR_CONTEXT = qualifiedResource(CLASS, "circular.xml"); @Test public void testFactoryBeanReturnsNull() throws Exception { @@ -96,6 +102,23 @@ public final class FactoryBeanTests { factory.getBeansOfType(AbstractFactoryBean.class); } + @Test + public void testCircularReferenceWithPostProcessor() { + DefaultListableBeanFactory factory = new DefaultListableBeanFactory(); + new XmlBeanDefinitionReader(factory).loadBeanDefinitions(CIRCULAR_CONTEXT); + + CountingPostProcessor counter = new CountingPostProcessor(); + factory.addBeanPostProcessor(counter); + + BeanImpl1 impl1 = factory.getBean(BeanImpl1.class); + assertNotNull(impl1); + assertNotNull(impl1.getImpl2()); + assertNotNull(impl1.getImpl2()); + assertSame(impl1, impl1.getImpl2().getImpl1()); + assertEquals(1, counter.getCount("bean1")); + assertEquals(1, counter.getCount("bean2")); + } + public static class NullReturningFactoryBean implements FactoryBean { @@ -193,7 +216,114 @@ public final class FactoryBeanTests { } } + public abstract static class AbstractFactoryBean implements FactoryBean { } + + public static class PassThroughFactoryBean implements FactoryBean, BeanFactoryAware { + + private Class type; + + private String instanceName; + + private BeanFactory beanFactory; + + private T instance; + + public PassThroughFactoryBean(Class type) { + this.type = type; + } + + public void setInstanceName(String instanceName) { + this.instanceName = instanceName; + } + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.beanFactory = beanFactory; + } + + + @Override + public T getObject() { + if (instance == null) { + instance = beanFactory.getBean(instanceName, type); + } + return instance; + } + + @Override + public Class getObjectType() { + return type; + } + + @Override + public boolean isSingleton() { + return true; + } + } + + + public static class CountingPostProcessor implements BeanPostProcessor { + + private final Map count = new HashMap(); + + @Override + public Object postProcessBeforeInitialization(Object bean, String beanName) { + return bean; + } + + @Override + public Object postProcessAfterInitialization(Object bean, String beanName) { + if (bean instanceof FactoryBean) { + return bean; + } + AtomicInteger c = count.get(beanName); + if (c == null) { + c = new AtomicInteger(0); + count.put(beanName, c); + } + c.incrementAndGet(); + return bean; + } + + public int getCount(String beanName) { + AtomicInteger c = count.get(beanName); + if (c != null) { + return c.intValue(); + } + else { + return 0; + } + } + } + + + public static class BeanImpl1 { + + private BeanImpl2 impl2; + + public BeanImpl2 getImpl2() { + return impl2; + } + + public void setImpl2(BeanImpl2 impl2) { + this.impl2 = impl2; + } + } + + + public static class BeanImpl2 { + + private BeanImpl1 impl1; + + public BeanImpl1 getImpl1() { + return impl1; + } + + public void setImpl1(BeanImpl1 impl1) { + this.impl1 = impl1; + } + } + }