diff --git a/framework-platform/framework-platform.gradle b/framework-platform/framework-platform.gradle index 0d5816e404..2c948581b7 100644 --- a/framework-platform/framework-platform.gradle +++ b/framework-platform/framework-platform.gradle @@ -46,6 +46,7 @@ dependencies { api("commons-io:commons-io:2.11.0") api("de.bechte.junit:junit-hierarchicalcontextrunner:4.12.1") api("info.picocli:picocli:4.6.3") + api("io.micrometer:context-propagation:1.0.0-M5") api("io.mockk:mockk:1.12.1") api("io.projectreactor.tools:blockhound:1.0.6.RELEASE") api("io.r2dbc:r2dbc-h2:1.0.0.RC1") diff --git a/spring-test/spring-test.gradle b/spring-test/spring-test.gradle index 87bdd7cba6..320820dd8b 100644 --- a/spring-test/spring-test.gradle +++ b/spring-test/spring-test.gradle @@ -56,6 +56,7 @@ dependencies { testImplementation(testFixtures(project(":spring-core"))) testImplementation(testFixtures(project(":spring-tx"))) testImplementation(testFixtures(project(":spring-web"))) + testImplementation('io.micrometer:context-propagation') testImplementation("jakarta.annotation:jakarta.annotation-api") testImplementation("javax.cache:cache-api") testImplementation("jakarta.ejb:jakarta.ejb-api") diff --git a/spring-webmvc/spring-webmvc.gradle b/spring-webmvc/spring-webmvc.gradle index 39f712ab2a..9e36938f2b 100644 --- a/spring-webmvc/spring-webmvc.gradle +++ b/spring-webmvc/spring-webmvc.gradle @@ -19,6 +19,7 @@ dependencies { optional("jakarta.servlet.jsp.jstl:jakarta.servlet.jsp.jstl-api") optional("jakarta.el:jakarta.el-api") optional("jakarta.xml.bind:jakarta.xml.bind-api") + optional('io.micrometer:context-propagation') optional("org.webjars:webjars-locator-core") optional("com.rometools:rome") optional("com.github.librepdf:openpdf") diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java index 107fc4cd7b..e23fa38ca6 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,11 +26,14 @@ import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import io.micrometer.context.ContextSnapshot; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import org.springframework.core.MethodParameter; import org.springframework.core.ReactiveAdapter; @@ -126,8 +129,18 @@ class ReactiveTypeHandler { ModelAndViewContainer mav, NativeWebRequest request) throws Exception { Assert.notNull(returnValue, "Expected return value"); - ReactiveAdapter adapter = this.adapterRegistry.getAdapter(returnValue.getClass()); - Assert.state(adapter != null, () -> "Unexpected return value: " + returnValue); + Class clazz = returnValue.getClass(); + ReactiveAdapter adapter = this.adapterRegistry.getAdapter(clazz); + Assert.state(adapter != null, () -> "Unexpected return value type: " + clazz); + + if (Mono.class.isAssignableFrom(clazz)) { + ContextSnapshot snapshot = ContextSnapshot.captureAll(); + returnValue = ((Mono) returnValue).contextWrite(snapshot::updateContext); + } + else if (Flux.class.isAssignableFrom(clazz)) { + ContextSnapshot snapshot = ContextSnapshot.captureAll(); + returnValue = ((Flux) returnValue).contextWrite(snapshot::updateContext); + } ResolvableType elementType = ResolvableType.forMethodParameter(returnType).getGeneric(); Class elementClass = elementType.toClass(); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandlerTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandlerTests.java index d3415b515a..83b4fe44ec 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandlerTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandlerTests.java @@ -18,13 +18,19 @@ package org.springframework.web.servlet.mvc.method.annotation; import java.util.Collections; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import io.micrometer.context.ContextRegistry; +import io.micrometer.context.ContextSnapshot; +import io.micrometer.context.ContextSnapshot.Scope; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Schedulers; import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; @@ -239,6 +245,46 @@ public class ResponseBodyEmitterReturnValueHandlerTests { assertThat(this.response.getContentAsString()).isEqualTo("data:foo\n\ndata:bar\n\ndata:baz\n\n"); } + @SuppressWarnings({"try","unused"}) + @Test + public void responseBodyFluxWithThreadLocal() throws Exception { + + this.request.addHeader("Accept", "text/event-stream"); + + ThreadLocal threadLocal = new ThreadLocal<>(); + ContextRegistry.getInstance().registerThreadLocalAccessor("key", threadLocal); + + CountDownLatch latch = new CountDownLatch(1); + + Flux flux = Flux.just("foo", "bar", "baz") + .publishOn(Schedulers.boundedElastic()) + .transformDeferredContextual((theFlux, contextView) -> + theFlux.map(s -> { + try (Scope scope = ContextSnapshot.setThreadLocalsFrom(contextView, "key")) { + return s + threadLocal.get(); + } + })) + .doOnTerminate(latch::countDown); + + try { + threadLocal.set("123"); + this.handler.handleReturnValue(flux, + on(TestController.class).resolveReturnType(Flux.class, String.class), + this.mavContainer, this.webRequest); + } + finally { + threadLocal.remove(); + } + + latch.await(5, TimeUnit.SECONDS); + + assertThat(this.request.isAsyncStarted()).isTrue(); + assertThat(this.response.getStatus()).isEqualTo(200); + + assertThat(this.response.getContentType()).isEqualTo("text/event-stream"); + assertThat(this.response.getContentAsString()).isEqualTo("data:foo123\n\ndata:bar123\n\ndata:baz123\n\n"); + } + @Test // gh-21972 public void responseBodyFluxWithError() throws Exception { diff --git a/src/docs/asciidoc/web/webmvc.adoc b/src/docs/asciidoc/web/webmvc.adoc index c5e2cca7f0..304133d5e2 100644 --- a/src/docs/asciidoc/web/webmvc.adoc +++ b/src/docs/asciidoc/web/webmvc.adoc @@ -4756,6 +4756,46 @@ suitable under load. If you plan to stream with a reactive type, you should use +[[mvc-ann-async-context-propagation]] +=== Context Propagation + +It is common to propagate context via `java.lang.ThreadLocal`. This works transparently +for handling on the same thread, but requires additional work for asynchronous handling +across multiple threads. The Micrometer +https://github.com/micrometer-metrics/context-propagation#context-propagation-library[Context Propagation] +library simplifies context propagation across threads, and across context mechanisms such +as `ThreadLocal` values, +Reactor https://projectreactor.io/docs/core/release/reference/#context[context], +GraphQL Java https://www.graphql-java.com/documentation/concerns/#context-objects[context], +and others. + +If Micrometer Context Propagation is present on the classpath, when a controller method +returns a <> such as `Flux` or `Mono`, all +`ThreadLocal` values, for which there is a registered `io.micrometer.ThreadLocalAccessor`, +are written to the Reactor `Context` as key-value pairs, using the key assigned by the +`ThreadLocalAccessor`. + +For other asynchronous handling scenarios, you can use the Context Propagation library +directly. For example: + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + // Capture ThreadLocal values from the main thread ... + ContextSnapshot snapshot = ContextSnapshot.captureAll(); + + // On a different thread: restore ThreadLocal values + try (ContextSnapshot.Scope scoped = snapshot.setThreadLocals()) { + // ... + } +---- + +For more details, see the +https://micrometer.io/docs/contextPropagation[documentation] of the Micrometer Context +Propagation library. + + + [[mvc-ann-async-disconnects]] === Disconnects [.small]#<>#