diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt index c6a16162b0..7a085f83a7 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt @@ -21,6 +21,7 @@ import kotlinx.coroutines.Job import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.reactor.awaitSingle import kotlinx.coroutines.reactor.mono +import kotlinx.coroutines.withContext import org.springframework.core.io.Resource import org.springframework.http.HttpMethod import org.springframework.http.HttpStatusCode @@ -72,6 +73,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct @PublishedApi internal val builder = RouterFunctions.route() + private var contextProvider: (suspend (ServerRequest) -> CoroutineContext)? = null + /** * Return a composed request predicate that tests against both this predicate AND * the [other] predicate (String processed as a path predicate). When evaluating the @@ -510,9 +513,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct */ fun resources(lookupFunction: suspend (ServerRequest) -> Resource?) { builder.resources { - mono(Dispatchers.Unconfined) { - lookupFunction.invoke(it) - } + asMono(it, handler = lookupFunction) } } @@ -534,7 +535,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct */ fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) { builder.filter { serverRequest, handlerFunction -> - mono(Dispatchers.Unconfined) { + asMono(serverRequest) { filterFunction(serverRequest) { handlerRequest -> if (handlerFunction is CoroutineContextAwareHandlerFunction<*>) { handlerFunction.handle(currentCoroutineContext().minusKey(Job.Key), handlerRequest).awaitSingle() @@ -578,7 +579,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct */ fun onError(predicate: (Throwable) -> Boolean, responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) { builder.onError(predicate) { throwable, request -> - mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) } + asMono(request) { responseProvider.invoke(throwable, request) } } } @@ -591,7 +592,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct */ inline fun onError(noinline responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) { builder.onError({it is E}) { throwable, request -> - mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) } + asMono(request) { responseProvider.invoke(throwable, request) } } } @@ -619,6 +620,19 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct builder.withAttributes(attributesConsumer) } + /** + * Allow to provide the default [CoroutineContext], potentially dynamically based on + * the incoming [ServerRequest]. + * @param provider the [CoroutineContext] provider + * @since 6.1.0 + */ + fun context(provider: suspend (ServerRequest) -> CoroutineContext) { + if (this.contextProvider != null) { + throw IllegalStateException("The Coroutine context provider should be defined not more than once") + } + this.contextProvider = provider + } + /** * Return a composed routing function created from all the registered routes. */ @@ -627,8 +641,22 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct return builder.build() } - private fun asHandlerFunction(handler: suspend (ServerRequest) -> T) = - CoroutineContextAwareHandlerFunction(handler) + @PublishedApi + internal fun asMono(request: ServerRequest, context: CoroutineContext = Dispatchers.Unconfined, handler: suspend (ServerRequest) -> T): Mono { + return mono(context) { + contextProvider?.let { + withContext(it.invoke(request)) { + handler.invoke(request) + } + } ?: run { + handler.invoke(request) + } + } + } + + private fun asHandlerFunction(handler: suspend (ServerRequest) -> ServerResponse) = CoroutineContextAwareHandlerFunction { request -> + handler.invoke(request) + } /** * @see ServerResponse.from @@ -698,7 +726,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct fun status(status: Int) = ServerResponse.status(status) - private class CoroutineContextAwareHandlerFunction( + private inner class CoroutineContextAwareHandlerFunction( private val handler: suspend (ServerRequest) -> T ) : HandlerFunction { @@ -706,7 +734,7 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct return handle(Dispatchers.Unconfined, request) } - fun handle(context: CoroutineContext, request: ServerRequest) = mono(context) { + fun handle(context: CoroutineContext, request: ServerRequest) = asMono(request, context) { handler(request) } diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt index 3b9ea7f1d8..b9fc1b9b12 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt @@ -16,11 +16,8 @@ package org.springframework.web.reactive.function.server -import kotlinx.coroutines.CoroutineName -import kotlinx.coroutines.currentCoroutineContext -import kotlinx.coroutines.withContext -import org.assertj.core.api.Assertions.assertThat -import org.assertj.core.api.Assertions.assertThatExceptionOfType +import kotlinx.coroutines.* +import org.assertj.core.api.Assertions.* import org.junit.jupiter.api.Test import org.springframework.core.io.ClassPathResource import org.springframework.http.HttpHeaders.ACCEPT @@ -179,6 +176,48 @@ class CoRouterFunctionDslTests { .verifyComplete() } + @Test + fun contextProvider() { + val mockRequest = get("https://example.com/") + .header("Custom-Header", "foo") + .build() + val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList()) + StepVerifier.create(routerWithContextProvider.route(request).flatMap { it.handle(request) }) + .expectNextMatches { response -> + response.headers().getFirst("context")!!.contains("foo") + } + .verifyComplete() + } + + @Test + fun contextProviderAndFilter() { + val mockRequest = get("https://example.com/") + .header("Custom-Header", "bar") + .build() + val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList()) + StepVerifier.create(routerWithContextProvider.route(request).flatMap { it.handle(request) }) + .expectNextMatches { response -> + response.headers().getFirst("context")!!.let { + it.contains("bar") && it.contains("Dispatchers.Default") + } + } + .verifyComplete() + } + + @Test + fun multipleContextProviders() { + assertThatIllegalStateException().isThrownBy { + coRouter { + context { + CoroutineName("foo") + } + context { + Dispatchers.Default + } + } + } + } + @Test fun attributes() { val visitor = AttributesTestVisitor() @@ -251,6 +290,25 @@ class CoRouterFunctionDslTests { } } + private val routerWithContextProvider = coRouter { + context { + CoroutineName(it.headers().firstHeader("Custom-Header")!!) + } + GET("/") { + ok().header("context", currentCoroutineContext().toString()).buildAndAwait() + } + filter { request, next -> + if (request.headers().firstHeader("Custom-Header") == "bar") { + withContext(currentCoroutineContext() + Dispatchers.Default) { + next.invoke(request) + } + } + else { + next.invoke(request) + } + } + } + private val otherRouter = router { "/other" { ok().build()