diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt index 0acb769e06..c6a16162b0 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ package org.springframework.web.reactive.function.server import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.reactor.awaitSingle import kotlinx.coroutines.reactor.mono import org.springframework.core.io.Resource @@ -24,7 +26,9 @@ import org.springframework.http.HttpMethod import org.springframework.http.HttpStatusCode import org.springframework.http.MediaType import org.springframework.web.reactive.function.server.RouterFunctions.nest +import reactor.core.publisher.Mono import java.net.URI +import kotlin.coroutines.CoroutineContext /** * Allow to create easily a WebFlux.fn [RouterFunction] with a [Coroutines router Kotlin DSL][CoRouterFunctionDsl]. @@ -532,7 +536,12 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct builder.filter { serverRequest, handlerFunction -> mono(Dispatchers.Unconfined) { filterFunction(serverRequest) { handlerRequest -> - handlerFunction.handle(handlerRequest).awaitSingle() + if (handlerFunction is CoroutineContextAwareHandlerFunction<*>) { + handlerFunction.handle(currentCoroutineContext().minusKey(Job.Key), handlerRequest).awaitSingle() + } + else { + handlerFunction.handle(handlerRequest).awaitSingle() + } } } } @@ -618,11 +627,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct return builder.build() } - private fun asHandlerFunction(init: suspend (ServerRequest) -> ServerResponse) = HandlerFunction { - mono(Dispatchers.Unconfined) { - init(it) - } - } + private fun asHandlerFunction(handler: suspend (ServerRequest) -> T) = + CoroutineContextAwareHandlerFunction(handler) /** * @see ServerResponse.from @@ -691,6 +697,21 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct */ fun status(status: Int) = ServerResponse.status(status) + + private class CoroutineContextAwareHandlerFunction( + private val handler: suspend (ServerRequest) -> T + ) : HandlerFunction { + + override fun handle(request: ServerRequest): Mono { + return handle(Dispatchers.Unconfined, request) + } + + fun handle(context: CoroutineContext, request: ServerRequest) = mono(context) { + handler(request) + } + + } + } /** diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt index 041956cefa..3b9ea7f1d8 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,17 +16,20 @@ package org.springframework.web.reactive.function.server +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.currentCoroutineContext +import kotlinx.coroutines.withContext import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.junit.jupiter.api.Test import org.springframework.core.io.ClassPathResource -import org.springframework.http.HttpHeaders.* -import org.springframework.http.HttpMethod.* +import org.springframework.http.HttpHeaders.ACCEPT +import org.springframework.http.HttpHeaders.CONTENT_TYPE +import org.springframework.http.HttpMethod.PATCH import org.springframework.http.HttpStatus import org.springframework.http.MediaType.* import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.* import org.springframework.web.testfixture.server.MockServerWebExchange -import org.springframework.web.reactive.function.server.AttributesTestVisitor import reactor.test.StepVerifier /** @@ -165,6 +168,17 @@ class CoRouterFunctionDslTests { .verifyComplete() } + @Test + fun filteringWithContext() { + val mockRequest = get("https://example.com/").build() + val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList()) + StepVerifier.create(filterRouterWithContext.route(request).flatMap { it.handle(request) }) + .expectNextMatches { response -> + response.headers().getFirst("context")!!.contains("Filter context") + } + .verifyComplete() + } + @Test fun attributes() { val visitor = AttributesTestVisitor() @@ -226,6 +240,17 @@ class CoRouterFunctionDslTests { } } + private val filterRouterWithContext = coRouter { + filter { request, next -> + withContext(CoroutineName("Filter context")) { + next(request) + } + } + GET("/") { + ok().header("context", currentCoroutineContext().toString()).buildAndAwait() + } + } + private val otherRouter = router { "/other" { ok().build()