Browse Source

Support context propagation for Spring MVC controllers

Closes gh-29056
pull/29282/head
rstoyanchev 2 years ago
parent
commit
b6c2e8de23
  1. 1
      framework-platform/framework-platform.gradle
  2. 1
      spring-test/spring-test.gradle
  3. 1
      spring-webmvc/spring-webmvc.gradle
  4. 19
      spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java
  5. 46
      spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandlerTests.java
  6. 40
      src/docs/asciidoc/web/webmvc.adoc

1
framework-platform/framework-platform.gradle

@ -46,6 +46,7 @@ dependencies { @@ -46,6 +46,7 @@ dependencies {
api("commons-io:commons-io:2.11.0")
api("de.bechte.junit:junit-hierarchicalcontextrunner:4.12.1")
api("info.picocli:picocli:4.6.3")
api("io.micrometer:context-propagation:1.0.0-M5")
api("io.mockk:mockk:1.12.1")
api("io.projectreactor.tools:blockhound:1.0.6.RELEASE")
api("io.r2dbc:r2dbc-h2:1.0.0.RC1")

1
spring-test/spring-test.gradle

@ -56,6 +56,7 @@ dependencies { @@ -56,6 +56,7 @@ dependencies {
testImplementation(testFixtures(project(":spring-core")))
testImplementation(testFixtures(project(":spring-tx")))
testImplementation(testFixtures(project(":spring-web")))
testImplementation('io.micrometer:context-propagation')
testImplementation("jakarta.annotation:jakarta.annotation-api")
testImplementation("javax.cache:cache-api")
testImplementation("jakarta.ejb:jakarta.ejb-api")

1
spring-webmvc/spring-webmvc.gradle

@ -19,6 +19,7 @@ dependencies { @@ -19,6 +19,7 @@ dependencies {
optional("jakarta.servlet.jsp.jstl:jakarta.servlet.jsp.jstl-api")
optional("jakarta.el:jakarta.el-api")
optional("jakarta.xml.bind:jakarta.xml.bind-api")
optional('io.micrometer:context-propagation')
optional("org.webjars:webjars-locator-core")
optional("com.rometools:rome")
optional("com.github.librepdf:openpdf")

19
spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,11 +26,14 @@ import java.util.Optional; @@ -26,11 +26,14 @@ import java.util.Optional;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import io.micrometer.context.ContextSnapshot;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapter;
@ -126,8 +129,18 @@ class ReactiveTypeHandler { @@ -126,8 +129,18 @@ class ReactiveTypeHandler {
ModelAndViewContainer mav, NativeWebRequest request) throws Exception {
Assert.notNull(returnValue, "Expected return value");
ReactiveAdapter adapter = this.adapterRegistry.getAdapter(returnValue.getClass());
Assert.state(adapter != null, () -> "Unexpected return value: " + returnValue);
Class<?> clazz = returnValue.getClass();
ReactiveAdapter adapter = this.adapterRegistry.getAdapter(clazz);
Assert.state(adapter != null, () -> "Unexpected return value type: " + clazz);
if (Mono.class.isAssignableFrom(clazz)) {
ContextSnapshot snapshot = ContextSnapshot.captureAll();
returnValue = ((Mono<?>) returnValue).contextWrite(snapshot::updateContext);
}
else if (Flux.class.isAssignableFrom(clazz)) {
ContextSnapshot snapshot = ContextSnapshot.captureAll();
returnValue = ((Flux<?>) returnValue).contextWrite(snapshot::updateContext);
}
ResolvableType elementType = ResolvableType.forMethodParameter(returnType).getGeneric();
Class<?> elementClass = elementType.toClass();

46
spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandlerTests.java

@ -18,13 +18,19 @@ package org.springframework.web.servlet.mvc.method.annotation; @@ -18,13 +18,19 @@ package org.springframework.web.servlet.mvc.method.annotation;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import io.micrometer.context.ContextRegistry;
import io.micrometer.context.ContextSnapshot;
import io.micrometer.context.ContextSnapshot.Scope;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;
import reactor.core.scheduler.Schedulers;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
@ -239,6 +245,46 @@ public class ResponseBodyEmitterReturnValueHandlerTests { @@ -239,6 +245,46 @@ public class ResponseBodyEmitterReturnValueHandlerTests {
assertThat(this.response.getContentAsString()).isEqualTo("data:foo\n\ndata:bar\n\ndata:baz\n\n");
}
@SuppressWarnings({"try","unused"})
@Test
public void responseBodyFluxWithThreadLocal() throws Exception {
this.request.addHeader("Accept", "text/event-stream");
ThreadLocal<String> threadLocal = new ThreadLocal<>();
ContextRegistry.getInstance().registerThreadLocalAccessor("key", threadLocal);
CountDownLatch latch = new CountDownLatch(1);
Flux<String> flux = Flux.just("foo", "bar", "baz")
.publishOn(Schedulers.boundedElastic())
.transformDeferredContextual((theFlux, contextView) ->
theFlux.map(s -> {
try (Scope scope = ContextSnapshot.setThreadLocalsFrom(contextView, "key")) {
return s + threadLocal.get();
}
}))
.doOnTerminate(latch::countDown);
try {
threadLocal.set("123");
this.handler.handleReturnValue(flux,
on(TestController.class).resolveReturnType(Flux.class, String.class),
this.mavContainer, this.webRequest);
}
finally {
threadLocal.remove();
}
latch.await(5, TimeUnit.SECONDS);
assertThat(this.request.isAsyncStarted()).isTrue();
assertThat(this.response.getStatus()).isEqualTo(200);
assertThat(this.response.getContentType()).isEqualTo("text/event-stream");
assertThat(this.response.getContentAsString()).isEqualTo("data:foo123\n\ndata:bar123\n\ndata:baz123\n\n");
}
@Test // gh-21972
public void responseBodyFluxWithError() throws Exception {

40
src/docs/asciidoc/web/webmvc.adoc

@ -4756,6 +4756,46 @@ suitable under load. If you plan to stream with a reactive type, you should use @@ -4756,6 +4756,46 @@ suitable under load. If you plan to stream with a reactive type, you should use
[[mvc-ann-async-context-propagation]]
=== Context Propagation
It is common to propagate context via `java.lang.ThreadLocal`. This works transparently
for handling on the same thread, but requires additional work for asynchronous handling
across multiple threads. The Micrometer
https://github.com/micrometer-metrics/context-propagation#context-propagation-library[Context Propagation]
library simplifies context propagation across threads, and across context mechanisms such
as `ThreadLocal` values,
Reactor https://projectreactor.io/docs/core/release/reference/#context[context],
GraphQL Java https://www.graphql-java.com/documentation/concerns/#context-objects[context],
and others.
If Micrometer Context Propagation is present on the classpath, when a controller method
returns a <<mvc-ann-async-reactive-types,reactive type>> such as `Flux` or `Mono`, all
`ThreadLocal` values, for which there is a registered `io.micrometer.ThreadLocalAccessor`,
are written to the Reactor `Context` as key-value pairs, using the key assigned by the
`ThreadLocalAccessor`.
For other asynchronous handling scenarios, you can use the Context Propagation library
directly. For example:
[source,java,indent=0,subs="verbatim,quotes",role="primary"]
.Java
----
// Capture ThreadLocal values from the main thread ...
ContextSnapshot snapshot = ContextSnapshot.captureAll();
// On a different thread: restore ThreadLocal values
try (ContextSnapshot.Scope scoped = snapshot.setThreadLocals()) {
// ...
}
----
For more details, see the
https://micrometer.io/docs/contextPropagation[documentation] of the Micrometer Context
Propagation library.
[[mvc-ann-async-disconnects]]
=== Disconnects
[.small]#<<web-reactive.adoc#webflux-codecs-streaming, WebFlux>>#

Loading…
Cancel
Save