Browse Source

Expose env and context in Kotlin beans DSL

This commit introduces a deferred initialization of the declared beans
in order to make it possible to access to the environment (and even
to the context for advanced use-cases) in the beans { } Kotlin DSL.

Issues: SPR-16269, SPR-16412
pull/1789/merge
sdeleuze 7 years ago
parent
commit
97ee94f4ca
  1. 125
      spring-context/src/main/kotlin/org/springframework/context/support/BeanDefinitionDsl.kt
  2. 46
      spring-context/src/test/kotlin/org/springframework/context/support/BeanDefinitionDslTests.kt

125
spring-context/src/main/kotlin/org/springframework/context/support/BeanDefinitionDsl.kt

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -63,7 +63,7 @@ import java.util.function.Supplier @@ -63,7 +63,7 @@ import java.util.function.Supplier
*/
fun beans(init: BeanDefinitionDsl.() -> Unit): BeanDefinitionDsl {
val beans = BeanDefinitionDsl()
beans.init()
beans.init = init
return beans
}
@ -79,12 +79,24 @@ fun beans(init: BeanDefinitionDsl.() -> Unit): BeanDefinitionDsl { @@ -79,12 +79,24 @@ fun beans(init: BeanDefinitionDsl.() -> Unit): BeanDefinitionDsl {
open class BeanDefinitionDsl(private val condition: (ConfigurableEnvironment) -> Boolean = { true })
: ApplicationContextInitializer<GenericApplicationContext> {
@PublishedApi
internal val registrations = arrayListOf<(GenericApplicationContext) -> Unit>()
@PublishedApi
internal val children = arrayListOf<BeanDefinitionDsl>()
internal lateinit var init: BeanDefinitionDsl.() -> Unit
/**
* Access to the context for advanced use-cases.
* @since 5.1
*/
lateinit var context: GenericApplicationContext
/**
* Shortcut for `context.environment`
* @since 5.1
*/
val env : ConfigurableEnvironment
get() = context.environment
/**
* Scope enum constants.
*/
@ -130,34 +142,6 @@ open class BeanDefinitionDsl(private val condition: (ConfigurableEnvironment) -> @@ -130,34 +142,6 @@ open class BeanDefinitionDsl(private val condition: (ConfigurableEnvironment) ->
}
/**
* Provide read access to some application context facilities.
* @constructor Create a new bean definition context.
* @param context the `ApplicationContext` instance to use for retrieving bean
* references, `Environment`, etc.
*/
inner class BeanDefinitionContext(@PublishedApi internal val context: GenericApplicationContext) {
/**
* Get a reference to the bean by type or type + name with the syntax
* `ref<Foo>()` or `ref<Foo>("foo")`. When leveraging Kotlin type inference
* it could be as short as `ref()` or `ref("foo")`.
* @param name the name of the bean to retrieve
* @param T type the bean must match, can be an interface or superclass
*/
inline fun <reified T : Any> ref(name: String? = null) : T = when (name) {
null -> context.getBean(T::class.java)
else -> context.getBean(name, T::class.java)
}
/**
* Get the [ConfigurableEnvironment] associated to the underlying [GenericApplicationContext].
*/
val env : ConfigurableEnvironment
get() = context.environment
}
/**
* Declare a bean definition from the given bean class which can be inferred when possible.
*
@ -177,23 +161,22 @@ open class BeanDefinitionDsl(private val condition: (ConfigurableEnvironment) -> @@ -177,23 +161,22 @@ open class BeanDefinitionDsl(private val condition: (ConfigurableEnvironment) ->
isPrimary: Boolean? = null,
autowireMode: Autowire = Autowire.CONSTRUCTOR,
isAutowireCandidate: Boolean? = null) {
registrations.add {
val customizer = BeanDefinitionCustomizer { bd ->
scope?.let { bd.scope = scope.name.toLowerCase() }
isLazyInit?.let { bd.isLazyInit = isLazyInit }
isPrimary?.let { bd.isPrimary = isPrimary }
isAutowireCandidate?.let { bd.isAutowireCandidate = isAutowireCandidate }
if (bd is AbstractBeanDefinition) {
bd.autowireMode = autowireMode.ordinal
}
}
when (name) {
null -> it.registerBean(T::class.java, customizer)
else -> it.registerBean(name, T::class.java, customizer)
val customizer = BeanDefinitionCustomizer { bd ->
scope?.let { bd.scope = scope.name.toLowerCase() }
isLazyInit?.let { bd.isLazyInit = isLazyInit }
isPrimary?.let { bd.isPrimary = isPrimary }
isAutowireCandidate?.let { bd.isAutowireCandidate = isAutowireCandidate }
if (bd is AbstractBeanDefinition) {
bd.autowireMode = autowireMode.ordinal
}
}
when (name) {
null -> context.registerBean(T::class.java, customizer)
else -> context.registerBean(name, T::class.java, customizer)
}
}
/**
@ -216,7 +199,7 @@ open class BeanDefinitionDsl(private val condition: (ConfigurableEnvironment) -> @@ -216,7 +199,7 @@ open class BeanDefinitionDsl(private val condition: (ConfigurableEnvironment) ->
isPrimary: Boolean? = null,
autowireMode: Autowire = Autowire.NO,
isAutowireCandidate: Boolean? = null,
crossinline function: BeanDefinitionContext.() -> T) {
crossinline function: () -> T) {
val customizer = BeanDefinitionCustomizer { bd ->
scope?.let { bd.scope = scope.name.toLowerCase() }
@ -228,26 +211,36 @@ open class BeanDefinitionDsl(private val condition: (ConfigurableEnvironment) -> @@ -228,26 +211,36 @@ open class BeanDefinitionDsl(private val condition: (ConfigurableEnvironment) ->
}
}
registrations.add {
val beanContext = BeanDefinitionContext(it)
when (name) {
null -> it.registerBean(T::class.java,
Supplier { function.invoke(beanContext) }, customizer)
else -> it.registerBean(name, T::class.java,
Supplier { function.invoke(beanContext) }, customizer)
}
when (name) {
null -> context.registerBean(T::class.java,
Supplier { function.invoke() }, customizer)
else -> context.registerBean(name, T::class.java,
Supplier { function.invoke() }, customizer)
}
}
/**
* Get a reference to the bean by type or type + name with the syntax
* `ref<Foo>()` or `ref<Foo>("foo")`. When leveraging Kotlin type inference
* it could be as short as `ref()` or `ref("foo")`.
* @param name the name of the bean to retrieve
* @param T type the bean must match, can be an interface or superclass
*/
inline fun <reified T : Any> ref(name: String? = null) : T = when (name) {
null -> context.getBean(T::class.java)
else -> context.getBean(name, T::class.java)
}
/**
* Take in account bean definitions enclosed in the provided lambda only when the
* specified profile is active.
*/
fun profile(profile: String, init: BeanDefinitionDsl.() -> Unit): BeanDefinitionDsl {
fun profile(profile: String, init: BeanDefinitionDsl.() -> Unit) {
val beans = BeanDefinitionDsl({ it.activeProfiles.contains(profile) })
beans.init()
beans.init = init
children.add(beans)
return beans
}
/**
@ -257,11 +250,10 @@ open class BeanDefinitionDsl(private val condition: (ConfigurableEnvironment) -> @@ -257,11 +250,10 @@ open class BeanDefinitionDsl(private val condition: (ConfigurableEnvironment) ->
* bean definition block
*/
fun environment(condition: ConfigurableEnvironment.() -> Boolean,
init: BeanDefinitionDsl.() -> Unit): BeanDefinitionDsl {
init: BeanDefinitionDsl.() -> Unit) {
val beans = BeanDefinitionDsl(condition::invoke)
beans.init()
beans.init = init
children.add(beans)
return beans
}
/**
@ -269,13 +261,10 @@ open class BeanDefinitionDsl(private val condition: (ConfigurableEnvironment) -> @@ -269,13 +261,10 @@ open class BeanDefinitionDsl(private val condition: (ConfigurableEnvironment) ->
* @param context The `ApplicationContext` to use for registering the beans
*/
override fun initialize(context: GenericApplicationContext) {
for (registration in registrations) {
if (condition.invoke(context.environment)) {
registration.invoke(context)
}
}
this.context = context
for (child in children) {
child.initialize(context)
}
init()
}
}

46
spring-context/src/test/kotlin/org/springframework/context/support/BeanDefinitionDslTests.kt

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,16 +20,18 @@ import org.junit.Assert.* @@ -20,16 +20,18 @@ import org.junit.Assert.*
import org.junit.Test
import org.springframework.beans.factory.NoSuchBeanDefinitionException
import org.springframework.beans.factory.getBean
import org.springframework.beans.factory.getBeansOfType
import org.springframework.context.support.BeanDefinitionDsl.*
import org.springframework.core.env.SimpleCommandLinePropertySource
import org.springframework.core.env.get
import org.springframework.mock.env.MockPropertySource
@Suppress("UNUSED_EXPRESSION")
class BeanDefinitionDslTests {
@Test
fun `Declare beans with the functional Kotlin DSL`() {
val beans = beans {
val beans = beans {
bean<Foo>()
bean<Bar>("bar", scope = Scope.PROTOTYPE)
bean { Baz(ref()) }
@ -87,7 +89,7 @@ class BeanDefinitionDslTests { @@ -87,7 +89,7 @@ class BeanDefinitionDslTests {
}
}
val context = GenericApplicationContext().apply {
val context = GenericApplicationContext().apply {
environment.propertySources.addFirst(SimpleCommandLinePropertySource("--name=foofoo"))
beans.initialize(this)
refresh()
@ -103,6 +105,43 @@ class BeanDefinitionDslTests { @@ -103,6 +105,43 @@ class BeanDefinitionDslTests {
val foofoo = context.getBean<FooFoo>()
assertEquals("foofoo", foofoo.name)
}
@Test // SPR-16412
fun `Declare beans depending on environment properties`() {
val beans = beans {
val n = env["number-of-beans"].toInt()
for (i in 1..n) {
bean("string$i") { Foo() }
}
}
val context = GenericApplicationContext().apply {
environment.propertySources.addLast(MockPropertySource().withProperty("number-of-beans", 5))
beans.initialize(this)
refresh()
}
for (i in 1..5) {
assertNotNull(context.getBean("string$i"))
}
}
@Test // SPR-16269
fun `Provide access to the context for allowing calling advanced features like getBeansOfType`() {
val beans = beans {
bean<Foo>("foo1")
bean<Foo>("foo2")
bean { BarBar(context.getBeansOfType<Foo>().values) }
}
val context = GenericApplicationContext().apply {
beans.initialize(this)
refresh()
}
val barbar = context.getBean<BarBar>()
assertEquals(2, barbar.foos.size)
}
}
@ -110,3 +149,4 @@ class Foo @@ -110,3 +149,4 @@ class Foo
class Bar
class Baz(val bar: Bar)
class FooFoo(val name: String)
class BarBar(val foos: Collection<Foo>)

Loading…
Cancel
Save