Browse Source

Defensively check for pre-resolved FactoryBean.getObject() results in circular reference scenarios

Issue: SPR-11937
pull/578/head
Juergen Hoeller 11 years ago
parent
commit
d870b382da
  1. 51
      spring-beans/src/main/java/org/springframework/beans/factory/support/FactoryBeanRegistrySupport.java
  2. 25
      spring-beans/src/test/java/org/springframework/beans/factory/FactoryBeanTests-circular.xml
  3. 140
      spring-beans/src/test/java/org/springframework/beans/factory/FactoryBeanTests.java

51
spring-beans/src/main/java/org/springframework/beans/factory/support/FactoryBeanRegistrySupport.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2013 the original author or authors. * Copyright 2002-2014 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -90,7 +90,7 @@ public abstract class FactoryBeanRegistrySupport extends DefaultSingletonBeanReg
* Obtain an object to expose from the given FactoryBean. * Obtain an object to expose from the given FactoryBean.
* @param factory the FactoryBean instance * @param factory the FactoryBean instance
* @param beanName the name of the bean * @param beanName the name of the bean
* @param shouldPostProcess whether the bean is subject for post-processing * @param shouldPostProcess whether the bean is subject to post-processing
* @return the object obtained from the FactoryBean * @return the object obtained from the FactoryBean
* @throws BeanCreationException if FactoryBean object creation failed * @throws BeanCreationException if FactoryBean object creation failed
* @see org.springframework.beans.factory.FactoryBean#getObject() * @see org.springframework.beans.factory.FactoryBean#getObject()
@ -100,14 +100,40 @@ public abstract class FactoryBeanRegistrySupport extends DefaultSingletonBeanReg
synchronized (getSingletonMutex()) { synchronized (getSingletonMutex()) {
Object object = this.factoryBeanObjectCache.get(beanName); Object object = this.factoryBeanObjectCache.get(beanName);
if (object == null) { if (object == null) {
object = doGetObjectFromFactoryBean(factory, beanName, shouldPostProcess); object = doGetObjectFromFactoryBean(factory, beanName);
this.factoryBeanObjectCache.put(beanName, (object != null ? object : NULL_OBJECT)); // Only post-process and store if not put there already during getObject() call above
// (e.g. because of circular reference processing triggered by custom getBean calls)
Object alreadyThere = this.factoryBeanObjectCache.get(beanName);
if (alreadyThere != null) {
object = alreadyThere;
}
else {
if (object != null && shouldPostProcess) {
try {
object = postProcessObjectFromFactoryBean(object, beanName);
}
catch (Throwable ex) {
throw new BeanCreationException(beanName,
"Post-processing of FactoryBean's singleton object failed", ex);
}
}
this.factoryBeanObjectCache.put(beanName, (object != null ? object : NULL_OBJECT));
}
} }
return (object != NULL_OBJECT ? object : null); return (object != NULL_OBJECT ? object : null);
} }
} }
else { else {
return doGetObjectFromFactoryBean(factory, beanName, shouldPostProcess); Object object = doGetObjectFromFactoryBean(factory, beanName);
if (object != null && shouldPostProcess) {
try {
object = postProcessObjectFromFactoryBean(object, beanName);
}
catch (Throwable ex) {
throw new BeanCreationException(beanName, "Post-processing of FactoryBean's object failed", ex);
}
}
return object;
} }
} }
@ -115,13 +141,11 @@ public abstract class FactoryBeanRegistrySupport extends DefaultSingletonBeanReg
* Obtain an object to expose from the given FactoryBean. * Obtain an object to expose from the given FactoryBean.
* @param factory the FactoryBean instance * @param factory the FactoryBean instance
* @param beanName the name of the bean * @param beanName the name of the bean
* @param shouldPostProcess whether the bean is subject for post-processing
* @return the object obtained from the FactoryBean * @return the object obtained from the FactoryBean
* @throws BeanCreationException if FactoryBean object creation failed * @throws BeanCreationException if FactoryBean object creation failed
* @see org.springframework.beans.factory.FactoryBean#getObject() * @see org.springframework.beans.factory.FactoryBean#getObject()
*/ */
private Object doGetObjectFromFactoryBean( private Object doGetObjectFromFactoryBean(final FactoryBean<?> factory, final String beanName)
final FactoryBean<?> factory, final String beanName, final boolean shouldPostProcess)
throws BeanCreationException { throws BeanCreationException {
Object object; Object object;
@ -151,23 +175,12 @@ public abstract class FactoryBeanRegistrySupport extends DefaultSingletonBeanReg
throw new BeanCreationException(beanName, "FactoryBean threw exception on object creation", ex); throw new BeanCreationException(beanName, "FactoryBean threw exception on object creation", ex);
} }
// Do not accept a null value for a FactoryBean that's not fully // Do not accept a null value for a FactoryBean that's not fully
// initialized yet: Many FactoryBeans just return null then. // initialized yet: Many FactoryBeans just return null then.
if (object == null && isSingletonCurrentlyInCreation(beanName)) { if (object == null && isSingletonCurrentlyInCreation(beanName)) {
throw new BeanCurrentlyInCreationException( throw new BeanCurrentlyInCreationException(
beanName, "FactoryBean which is currently in creation returned null from getObject"); beanName, "FactoryBean which is currently in creation returned null from getObject");
} }
if (object != null && shouldPostProcess) {
try {
object = postProcessObjectFromFactoryBean(object, beanName);
}
catch (Throwable ex) {
throw new BeanCreationException(beanName, "Post-processing of the FactoryBean's object failed", ex);
}
}
return object; return object;
} }

25
spring-beans/src/test/java/org/springframework/beans/factory/FactoryBeanTests-circular.xml

@ -0,0 +1,25 @@
<?xml version="1.0" encoding="ISO-8859-1"?>
<beans xmlns="http://www.springframework.org/schema/beans"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://www.springframework.org/schema/beans
http://www.springframework.org/schema/beans/spring-beans-3.0.xsd">
<bean id="bean1" class="org.springframework.beans.factory.FactoryBeanTests$PassThroughFactoryBean" primary="true">
<constructor-arg value="org.springframework.beans.factory.FactoryBeanTests$BeanImpl1"/>
<property name="instanceName" value="beanImpl1"/>
</bean>
<bean id="beanImpl1" class="org.springframework.beans.factory.FactoryBeanTests$BeanImpl1">
<property name="impl2" ref="bean2"/>
</bean>
<bean id="bean2" class="org.springframework.beans.factory.FactoryBeanTests$PassThroughFactoryBean" primary="true">
<constructor-arg value="org.springframework.beans.factory.FactoryBeanTests$BeanImpl2"/>
<property name="instanceName" value="beanImpl2"/>
</bean>
<bean id="beanImpl2" class="org.springframework.beans.factory.FactoryBeanTests$BeanImpl2">
<property name="impl1" ref="bean1"/>
</bean>
</beans>

140
spring-beans/src/test/java/org/springframework/beans/factory/FactoryBeanTests.java

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2013 the original author or authors. * Copyright 2002-2014 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,19 +16,24 @@
package org.springframework.beans.factory; package org.springframework.beans.factory;
import static org.junit.Assert.assertEquals; import java.util.HashMap;
import static org.junit.Assert.assertNull; import java.util.Map;
import static org.junit.Assert.assertSame; import java.util.concurrent.atomic.AtomicInteger;
import static org.springframework.tests.TestResourceUtils.qualifiedResource;
import org.junit.Test; import org.junit.Test;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; import org.springframework.beans.factory.xml.XmlBeanDefinitionReader;
import org.springframework.core.io.Resource; import org.springframework.core.io.Resource;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import static org.junit.Assert.*;
import static org.springframework.tests.TestResourceUtils.*;
/** /**
* @author Rob Harrop * @author Rob Harrop
* @author Juergen Hoeller * @author Juergen Hoeller
@ -40,6 +45,7 @@ public final class FactoryBeanTests {
private static final Resource RETURNS_NULL_CONTEXT = qualifiedResource(CLASS, "returnsNull.xml"); private static final Resource RETURNS_NULL_CONTEXT = qualifiedResource(CLASS, "returnsNull.xml");
private static final Resource WITH_AUTOWIRING_CONTEXT = qualifiedResource(CLASS, "withAutowiring.xml"); private static final Resource WITH_AUTOWIRING_CONTEXT = qualifiedResource(CLASS, "withAutowiring.xml");
private static final Resource ABSTRACT_CONTEXT = qualifiedResource(CLASS, "abstract.xml"); private static final Resource ABSTRACT_CONTEXT = qualifiedResource(CLASS, "abstract.xml");
private static final Resource CIRCULAR_CONTEXT = qualifiedResource(CLASS, "circular.xml");
@Test @Test
public void testFactoryBeanReturnsNull() throws Exception { public void testFactoryBeanReturnsNull() throws Exception {
@ -96,6 +102,23 @@ public final class FactoryBeanTests {
factory.getBeansOfType(AbstractFactoryBean.class); factory.getBeansOfType(AbstractFactoryBean.class);
} }
@Test
public void testCircularReferenceWithPostProcessor() {
DefaultListableBeanFactory factory = new DefaultListableBeanFactory();
new XmlBeanDefinitionReader(factory).loadBeanDefinitions(CIRCULAR_CONTEXT);
CountingPostProcessor counter = new CountingPostProcessor();
factory.addBeanPostProcessor(counter);
BeanImpl1 impl1 = factory.getBean(BeanImpl1.class);
assertNotNull(impl1);
assertNotNull(impl1.getImpl2());
assertNotNull(impl1.getImpl2());
assertSame(impl1, impl1.getImpl2().getImpl1());
assertEquals(1, counter.getCount("bean1"));
assertEquals(1, counter.getCount("bean2"));
}
public static class NullReturningFactoryBean implements FactoryBean<Object> { public static class NullReturningFactoryBean implements FactoryBean<Object> {
@ -193,7 +216,114 @@ public final class FactoryBeanTests {
} }
} }
public abstract static class AbstractFactoryBean implements FactoryBean<Object> { public abstract static class AbstractFactoryBean implements FactoryBean<Object> {
} }
public static class PassThroughFactoryBean<T> implements FactoryBean<T>, BeanFactoryAware {
private Class<T> type;
private String instanceName;
private BeanFactory beanFactory;
private T instance;
public PassThroughFactoryBean(Class<T> type) {
this.type = type;
}
public void setInstanceName(String instanceName) {
this.instanceName = instanceName;
}
@Override
public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
this.beanFactory = beanFactory;
}
@Override
public T getObject() {
if (instance == null) {
instance = beanFactory.getBean(instanceName, type);
}
return instance;
}
@Override
public Class<?> getObjectType() {
return type;
}
@Override
public boolean isSingleton() {
return true;
}
}
public static class CountingPostProcessor implements BeanPostProcessor {
private final Map<String, AtomicInteger> count = new HashMap<String, AtomicInteger>();
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) {
return bean;
}
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) {
if (bean instanceof FactoryBean) {
return bean;
}
AtomicInteger c = count.get(beanName);
if (c == null) {
c = new AtomicInteger(0);
count.put(beanName, c);
}
c.incrementAndGet();
return bean;
}
public int getCount(String beanName) {
AtomicInteger c = count.get(beanName);
if (c != null) {
return c.intValue();
}
else {
return 0;
}
}
}
public static class BeanImpl1 {
private BeanImpl2 impl2;
public BeanImpl2 getImpl2() {
return impl2;
}
public void setImpl2(BeanImpl2 impl2) {
this.impl2 = impl2;
}
}
public static class BeanImpl2 {
private BeanImpl1 impl1;
public BeanImpl1 getImpl1() {
return impl1;
}
public void setImpl1(BeanImpl1 impl1) {
this.impl1 = impl1;
}
}
} }

Loading…
Cancel
Save