Browse Source

Add context function to CoRouterFunctionDsl

This new function allows to customize the CoroutineContext
potentially dynamically based on the incoming
ServerRequest.

Closes gh-27010
pull/31132/head
Sébastien Deleuze 1 year ago
parent
commit
38392233ba
  1. 48
      spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt
  2. 68
      spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt

48
spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt

@ -21,6 +21,7 @@ import kotlinx.coroutines.Job @@ -21,6 +21,7 @@ import kotlinx.coroutines.Job
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.reactor.awaitSingle
import kotlinx.coroutines.reactor.mono
import kotlinx.coroutines.withContext
import org.springframework.core.io.Resource
import org.springframework.http.HttpMethod
import org.springframework.http.HttpStatusCode
@ -72,6 +73,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct @@ -72,6 +73,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
@PublishedApi
internal val builder = RouterFunctions.route()
private var contextProvider: (suspend (ServerRequest) -> CoroutineContext)? = null
/**
* Return a composed request predicate that tests against both this predicate AND
* the [other] predicate (String processed as a path predicate). When evaluating the
@ -510,9 +513,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct @@ -510,9 +513,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
*/
fun resources(lookupFunction: suspend (ServerRequest) -> Resource?) {
builder.resources {
mono(Dispatchers.Unconfined) {
lookupFunction.invoke(it)
}
asMono(it, handler = lookupFunction)
}
}
@ -534,7 +535,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct @@ -534,7 +535,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
*/
fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) {
builder.filter { serverRequest, handlerFunction ->
mono(Dispatchers.Unconfined) {
asMono(serverRequest) {
filterFunction(serverRequest) { handlerRequest ->
if (handlerFunction is CoroutineContextAwareHandlerFunction<*>) {
handlerFunction.handle(currentCoroutineContext().minusKey(Job.Key), handlerRequest).awaitSingle()
@ -578,7 +579,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct @@ -578,7 +579,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
*/
fun onError(predicate: (Throwable) -> Boolean, responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
builder.onError(predicate) { throwable, request ->
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) }
asMono(request) { responseProvider.invoke(throwable, request) }
}
}
@ -591,7 +592,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct @@ -591,7 +592,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
*/
inline fun <reified E : Throwable> onError(noinline responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
builder.onError({it is E}) { throwable, request ->
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) }
asMono(request) { responseProvider.invoke(throwable, request) }
}
}
@ -619,6 +620,19 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct @@ -619,6 +620,19 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
builder.withAttributes(attributesConsumer)
}
/**
* Allow to provide the default [CoroutineContext], potentially dynamically based on
* the incoming [ServerRequest].
* @param provider the [CoroutineContext] provider
* @since 6.1.0
*/
fun context(provider: suspend (ServerRequest) -> CoroutineContext) {
if (this.contextProvider != null) {
throw IllegalStateException("The Coroutine context provider should be defined not more than once")
}
this.contextProvider = provider
}
/**
* Return a composed routing function created from all the registered routes.
*/
@ -627,8 +641,22 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct @@ -627,8 +641,22 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
return builder.build()
}
private fun <T : ServerResponse> asHandlerFunction(handler: suspend (ServerRequest) -> T) =
CoroutineContextAwareHandlerFunction(handler)
@PublishedApi
internal fun <T> asMono(request: ServerRequest, context: CoroutineContext = Dispatchers.Unconfined, handler: suspend (ServerRequest) -> T): Mono<T> {
return mono(context) {
contextProvider?.let {
withContext(it.invoke(request)) {
handler.invoke(request)
}
} ?: run {
handler.invoke(request)
}
}
}
private fun asHandlerFunction(handler: suspend (ServerRequest) -> ServerResponse) = CoroutineContextAwareHandlerFunction { request ->
handler.invoke(request)
}
/**
* @see ServerResponse.from
@ -698,7 +726,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct @@ -698,7 +726,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
fun status(status: Int) = ServerResponse.status(status)
private class CoroutineContextAwareHandlerFunction<T : ServerResponse>(
private inner class CoroutineContextAwareHandlerFunction<T : ServerResponse>(
private val handler: suspend (ServerRequest) -> T
) : HandlerFunction<T> {
@ -706,7 +734,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct @@ -706,7 +734,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
return handle(Dispatchers.Unconfined, request)
}
fun handle(context: CoroutineContext, request: ServerRequest) = mono(context) {
fun handle(context: CoroutineContext, request: ServerRequest) = asMono(request, context) {
handler(request)
}

68
spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt

@ -16,11 +16,8 @@ @@ -16,11 +16,8 @@
package org.springframework.web.reactive.function.server
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.withContext
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import kotlinx.coroutines.*
import org.assertj.core.api.Assertions.*
import org.junit.jupiter.api.Test
import org.springframework.core.io.ClassPathResource
import org.springframework.http.HttpHeaders.ACCEPT
@ -179,6 +176,48 @@ class CoRouterFunctionDslTests { @@ -179,6 +176,48 @@ class CoRouterFunctionDslTests {
.verifyComplete()
}
@Test
fun contextProvider() {
val mockRequest = get("https://example.com/")
.header("Custom-Header", "foo")
.build()
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
StepVerifier.create(routerWithContextProvider.route(request).flatMap { it.handle(request) })
.expectNextMatches { response ->
response.headers().getFirst("context")!!.contains("foo")
}
.verifyComplete()
}
@Test
fun contextProviderAndFilter() {
val mockRequest = get("https://example.com/")
.header("Custom-Header", "bar")
.build()
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
StepVerifier.create(routerWithContextProvider.route(request).flatMap { it.handle(request) })
.expectNextMatches { response ->
response.headers().getFirst("context")!!.let {
it.contains("bar") && it.contains("Dispatchers.Default")
}
}
.verifyComplete()
}
@Test
fun multipleContextProviders() {
assertThatIllegalStateException().isThrownBy {
coRouter {
context {
CoroutineName("foo")
}
context {
Dispatchers.Default
}
}
}
}
@Test
fun attributes() {
val visitor = AttributesTestVisitor()
@ -251,6 +290,25 @@ class CoRouterFunctionDslTests { @@ -251,6 +290,25 @@ class CoRouterFunctionDslTests {
}
}
private val routerWithContextProvider = coRouter {
context {
CoroutineName(it.headers().firstHeader("Custom-Header")!!)
}
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
}
filter { request, next ->
if (request.headers().firstHeader("Custom-Header") == "bar") {
withContext(currentCoroutineContext() + Dispatchers.Default) {
next.invoke(request)
}
}
else {
next.invoke(request)
}
}
}
private val otherRouter = router {
"/other" {
ok().build()

Loading…
Cancel
Save