diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java index 45e1e17974..fd20d1eec3 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/AbstractMessageWriterResultHandler.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Set; import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.core.KotlinDetector; @@ -130,7 +131,8 @@ public abstract class AbstractMessageWriterResultHandler extends HandlerResultHa if (adapter != null) { publisher = adapter.toPublisher(body); boolean isUnwrapped = KotlinDetector.isSuspendingFunction(bodyParameter.getMethod()) && - !COROUTINES_FLOW_CLASS_NAME.equals(bodyType.toClass().getName()); + !COROUTINES_FLOW_CLASS_NAME.equals(bodyType.toClass().getName()) && + !Flux.class.equals(bodyType.toClass()); ResolvableType genericType = isUnwrapped ? bodyType : bodyType.getGeneric(); elementType = getElementType(adapter, genericType); actualElementType = elementType; diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt index a1f6fa9a32..0b0acda630 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/result/method/annotation/CoroutinesIntegrationTests.kt @@ -37,6 +37,8 @@ import org.springframework.web.bind.annotation.RestController import org.springframework.web.client.HttpServerErrorException import org.springframework.web.reactive.config.EnableWebFlux import org.springframework.web.testfixture.http.server.reactive.bootstrap.HttpServer +import reactor.core.publisher.Flux +import java.time.Duration class CoroutinesIntegrationTests : AbstractRequestMappingIntegrationTests() { @@ -111,6 +113,25 @@ class CoroutinesIntegrationTests : AbstractRequestMappingIntegrationTests() { } } + @ParameterizedHttpServerTest + fun `Suspending handler method returning ResponseEntity of Flux `(httpServer: HttpServer) { + startServer(httpServer) + + val entity = performGet("/entity-flux", HttpHeaders.EMPTY, String::class.java) + assertThat(entity.statusCode).isEqualTo(HttpStatus.OK) + assertThat(entity.body).isEqualTo("01234") + } + + @ParameterizedHttpServerTest + fun `Suspending handler method returning ResponseEntity of Flow`(httpServer: HttpServer) { + startServer(httpServer) + + val entity = performGet("/entity-flow", HttpHeaders.EMPTY, String::class.java) + assertThat(entity.statusCode).isEqualTo(HttpStatus.OK) + assertThat(entity.body).isEqualTo("foobar") + } + + @Configuration @EnableWebFlux @ComponentScan(resourcePattern = "**/CoroutinesIntegrationTests*") @@ -169,6 +190,25 @@ class CoroutinesIntegrationTests : AbstractRequestMappingIntegrationTests() { throw IllegalStateException() } + @GetMapping("/entity-flux") + suspend fun entityFlux() : ResponseEntity> { + val strings = Flux.interval(Duration.ofMillis(100)).take(5) + .map { l -> l.toString() } + delay(1) + return ResponseEntity.ok().body(strings) + } + + @GetMapping("/entity-flow") + suspend fun entityFlow() : ResponseEntity> { + val strings = flow { + emit("foo") + delay(1) + emit("bar") + delay(1) + } + return ResponseEntity.ok().body(strings) + } + }