Browse Source

Propagate CoroutineContext in coRouter filters

Closes gh-26977
pull/31132/head
Sébastien Deleuze 1 year ago
parent
commit
d47c7f9552
  1. 35
      spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt
  2. 33
      spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt

35
spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,8 @@ @@ -17,6 +17,8 @@
package org.springframework.web.reactive.function.server
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.reactor.awaitSingle
import kotlinx.coroutines.reactor.mono
import org.springframework.core.io.Resource
@ -24,7 +26,9 @@ import org.springframework.http.HttpMethod @@ -24,7 +26,9 @@ import org.springframework.http.HttpMethod
import org.springframework.http.HttpStatusCode
import org.springframework.http.MediaType
import org.springframework.web.reactive.function.server.RouterFunctions.nest
import reactor.core.publisher.Mono
import java.net.URI
import kotlin.coroutines.CoroutineContext
/**
* Allow to create easily a WebFlux.fn [RouterFunction] with a [Coroutines router Kotlin DSL][CoRouterFunctionDsl].
@ -532,7 +536,12 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct @@ -532,7 +536,12 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
builder.filter { serverRequest, handlerFunction ->
mono(Dispatchers.Unconfined) {
filterFunction(serverRequest) { handlerRequest ->
handlerFunction.handle(handlerRequest).awaitSingle()
if (handlerFunction is CoroutineContextAwareHandlerFunction<*>) {
handlerFunction.handle(currentCoroutineContext().minusKey(Job.Key), handlerRequest).awaitSingle()
}
else {
handlerFunction.handle(handlerRequest).awaitSingle()
}
}
}
}
@ -618,11 +627,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct @@ -618,11 +627,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
return builder.build()
}
private fun asHandlerFunction(init: suspend (ServerRequest) -> ServerResponse) = HandlerFunction {
mono(Dispatchers.Unconfined) {
init(it)
}
}
private fun <T : ServerResponse> asHandlerFunction(handler: suspend (ServerRequest) -> T) =
CoroutineContextAwareHandlerFunction(handler)
/**
* @see ServerResponse.from
@ -691,6 +697,21 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct @@ -691,6 +697,21 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
*/
fun status(status: Int) = ServerResponse.status(status)
private class CoroutineContextAwareHandlerFunction<T : ServerResponse>(
private val handler: suspend (ServerRequest) -> T
) : HandlerFunction<T> {
override fun handle(request: ServerRequest): Mono<T> {
return handle(Dispatchers.Unconfined, request)
}
fun handle(context: CoroutineContext, request: ServerRequest) = mono(context) {
handler(request)
}
}
}
/**

33
spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,17 +16,20 @@ @@ -16,17 +16,20 @@
package org.springframework.web.reactive.function.server
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.withContext
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.jupiter.api.Test
import org.springframework.core.io.ClassPathResource
import org.springframework.http.HttpHeaders.*
import org.springframework.http.HttpMethod.*
import org.springframework.http.HttpHeaders.ACCEPT
import org.springframework.http.HttpHeaders.CONTENT_TYPE
import org.springframework.http.HttpMethod.PATCH
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType.*
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.*
import org.springframework.web.testfixture.server.MockServerWebExchange
import org.springframework.web.reactive.function.server.AttributesTestVisitor
import reactor.test.StepVerifier
/**
@ -165,6 +168,17 @@ class CoRouterFunctionDslTests { @@ -165,6 +168,17 @@ class CoRouterFunctionDslTests {
.verifyComplete()
}
@Test
fun filteringWithContext() {
val mockRequest = get("https://example.com/").build()
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
StepVerifier.create(filterRouterWithContext.route(request).flatMap { it.handle(request) })
.expectNextMatches { response ->
response.headers().getFirst("context")!!.contains("Filter context")
}
.verifyComplete()
}
@Test
fun attributes() {
val visitor = AttributesTestVisitor()
@ -226,6 +240,17 @@ class CoRouterFunctionDslTests { @@ -226,6 +240,17 @@ class CoRouterFunctionDslTests {
}
}
private val filterRouterWithContext = coRouter {
filter { request, next ->
withContext(CoroutineName("Filter context")) {
next(request)
}
}
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
}
}
private val otherRouter = router {
"/other" {
ok().build()

Loading…
Cancel
Save