diff --git a/org.springframework.context/src/main/java/org/springframework/context/support/DefaultLifecycleProcessor.java b/org.springframework.context/src/main/java/org/springframework/context/support/DefaultLifecycleProcessor.java index 2bc386d114..1a7121158b 100644 --- a/org.springframework.context/src/main/java/org/springframework/context/support/DefaultLifecycleProcessor.java +++ b/org.springframework.context/src/main/java/org/springframework/context/support/DefaultLifecycleProcessor.java @@ -30,6 +30,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.BeanFactoryUtils; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.context.Lifecycle; import org.springframework.context.LifecycleProcessor; @@ -216,17 +217,22 @@ public class DefaultLifecycleProcessor implements LifecycleProcessor, BeanFactor } } + /** + * Retrieve all applicable Lifecycle beans: all singletons that have already been created, + * as well as all SmartLifecycle beans (even if they are marked as lazy-init). + */ private Map getLifecycleBeans() { Map beans = new LinkedHashMap(); - Map smartLifecycles = - this.beanFactory.getBeansOfType(SmartLifecycle.class, false, true); - beans.putAll(smartLifecycles); - String[] singletonNames = this.beanFactory.getSingletonNames(); - for (String beanName : singletonNames) { - if (!beans.containsKey(beanName)) { - Object bean = this.beanFactory.getSingleton(beanName); - if (bean instanceof Lifecycle && !this.equals(bean)) { - beans.put(beanName, (Lifecycle) bean); + String[] beanNames = this.beanFactory.getBeanNamesForType(Lifecycle.class, false, false); + for (String beanName : beanNames) { + String beanNameToRegister = BeanFactoryUtils.transformedBeanName(beanName); + String beanNameToCheck = (this.beanFactory.isFactoryBean(beanNameToRegister) ? + BeanFactory.FACTORY_BEAN_PREFIX + beanName : beanName); + if (this.beanFactory.containsSingleton(beanNameToRegister) || + SmartLifecycle.class.isAssignableFrom(this.beanFactory.getType(beanNameToCheck))) { + Lifecycle bean = this.beanFactory.getBean(beanNameToCheck, Lifecycle.class); + if (bean != this) { + beans.put(beanNameToRegister, bean); } } } diff --git a/org.springframework.context/src/test/java/org/springframework/context/support/DefaultLifecycleProcessorTests.java b/org.springframework.context/src/test/java/org/springframework/context/support/DefaultLifecycleProcessorTests.java index 955ad0f8f7..b9a37d9617 100644 --- a/org.springframework.context/src/test/java/org/springframework/context/support/DefaultLifecycleProcessorTests.java +++ b/org.springframework.context/src/test/java/org/springframework/context/support/DefaultLifecycleProcessorTests.java @@ -22,6 +22,7 @@ import static org.junit.Assert.*; import org.junit.Test; import org.springframework.beans.DirectFieldAccessor; +import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.Lifecycle; @@ -85,6 +86,19 @@ public class DefaultLifecycleProcessorTests { assertFalse(bean.isRunning()); } + @Test + public void singleSmartLifecycleAutoStartupWithLazyInitFactoryBean() throws Exception { + StaticApplicationContext context = new StaticApplicationContext(); + RootBeanDefinition bd = new RootBeanDefinition(DummySmartLifecycleFactoryBean.class); + bd.setLazyInit(true); + context.registerBeanDefinition("bean", bd); + context.refresh(); + DummySmartLifecycleFactoryBean bean = context.getBean("&bean", DummySmartLifecycleFactoryBean.class); + assertTrue(bean.isRunning()); + context.stop(); + assertFalse(bean.isRunning()); + } + @Test public void singleSmartLifecycleWithoutAutoStartup() throws Exception { CopyOnWriteArrayList startedBeans = new CopyOnWriteArrayList(); @@ -654,4 +668,49 @@ public class DefaultLifecycleProcessorTests { } } + + public static class DummySmartLifecycleFactoryBean implements FactoryBean, SmartLifecycle { + + public boolean running = false; + + DummySmartLifecycleBean bean = new DummySmartLifecycleBean(); + + public Object getObject() throws Exception { + return this.bean; + } + + public Class getObjectType() { + return DummySmartLifecycleBean.class; + } + + public boolean isSingleton() { + return true; + } + + public boolean isAutoStartup() { + return true; + } + + public void stop(Runnable callback) { + this.running = false; + callback.run(); + } + + public void start() { + this.running = true; + } + + public void stop() { + this.running = false; + } + + public boolean isRunning() { + return this.running; + } + + public int getPhase() { + return 0; + } + } + }