Browse Source

Support @RestControllerAdvice in Standalone MockMvc again

Since Spring Framework 5.2, @RestControllerAdvice registered with
MockMvc when using MockMvcBuilders.standaloneSetup() has no longer been
properly supported if annotation attributes were declared in the
@RestControllerAdvice annotation. Prior to 5.2, this was not an issue.

The cause for this regression is two-fold.

1. Commit 50c257794f refactored
   DefaultListableBeanFactory so that findAnnotationOnBean() supports
   merged annotations; however, that commit did not refactor
   StaticListableBeanFactory#findAnnotationOnBean() to support merged
   annotations.

2. Commit 978adbdae7 refactored
   ControllerAdviceBean so that a merged @ControllerAdvice annotation
   is only looked up via ApplicationContext#findAnnotationOnBean().

The latter relies on the fact that findAnnotationOnBean() supports
merged annotations (e.g., @RestControllerAdvice as a merged instance of
@ControllerAdvice). Behind the scenes, MockMvcBuilders.standaloneSetup()
creates a StubWebApplicationContext which internally uses a
StubBeanFactory which extends StaticListableBeanFactory. Consequently,
since the implementation of findAnnotationOnBean() in
StaticListableBeanFactory was not updated to support merged annotations
like it was in DefaultListableBeanFactory, we only see this regression
with the standalone MockMvc support and not with MockMvc support for an
existing WebApplicationContext or with standard Spring applications
using an ApplicationContext that uses DefaultListableBeanFactory.

This commit fixes this regression by supporting merged annotations in
StaticListableBeanFactory#findAnnotationOnBean() as well.

Closes gh-25520
pull/25798/head
Sam Brannen 5 years ago
parent
commit
96da1ff9ea
  1. 4
      spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java
  2. 59
      spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java
  3. 156
      spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ExceptionHandlerTests.java

4
spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java

@ -37,7 +37,7 @@ import org.springframework.beans.factory.ObjectProvider; @@ -37,7 +37,7 @@ import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.SmartFactoryBean;
import org.springframework.core.OrderComparator;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
@ -450,7 +450,7 @@ public class StaticListableBeanFactory implements ListableBeanFactory { @@ -450,7 +450,7 @@ public class StaticListableBeanFactory implements ListableBeanFactory {
throws NoSuchBeanDefinitionException {
Class<?> beanType = getType(beanName);
return (beanType != null ? AnnotationUtils.findAnnotation(beanType, annotationType) : null);
return (beanType != null ? AnnotatedElementUtils.findMergedAnnotation(beanType, annotationType) : null);
}
}

59
spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java

@ -16,6 +16,8 @@ @@ -16,6 +16,8 @@
package org.springframework.beans.factory;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@ -33,6 +35,7 @@ import org.springframework.beans.testfixture.beans.TestAnnotation; @@ -33,6 +35,7 @@ import org.springframework.beans.testfixture.beans.TestAnnotation;
import org.springframework.beans.testfixture.beans.TestBean;
import org.springframework.beans.testfixture.beans.factory.DummyFactory;
import org.springframework.cglib.proxy.NoOp;
import org.springframework.core.annotation.AliasFor;
import org.springframework.core.io.Resource;
import org.springframework.util.ObjectUtils;
@ -324,6 +327,33 @@ public class BeanFactoryUtilsTests { @@ -324,6 +327,33 @@ public class BeanFactoryUtilsTests {
assertThat(Arrays.equals(new String[] { "buffer" }, deps)).isTrue();
}
@Test
public void findAnnotationOnBean() {
this.listableBeanFactory.registerSingleton("controllerAdvice", new ControllerAdviceClass());
this.listableBeanFactory.registerSingleton("restControllerAdvice", new RestControllerAdviceClass());
testFindAnnotationOnBean(this.listableBeanFactory);
}
@Test // gh-25520
public void findAnnotationOnBeanWithStaticFactory() {
StaticListableBeanFactory lbf = new StaticListableBeanFactory();
lbf.addBean("controllerAdvice", new ControllerAdviceClass());
lbf.addBean("restControllerAdvice", new RestControllerAdviceClass());
testFindAnnotationOnBean(lbf);
}
private void testFindAnnotationOnBean(ListableBeanFactory lbf) {
assertControllerAdvice(lbf, "controllerAdvice");
assertControllerAdvice(lbf, "restControllerAdvice");
}
private void assertControllerAdvice(ListableBeanFactory lbf, String beanName) {
ControllerAdvice controllerAdvice = lbf.findAnnotationOnBean(beanName, ControllerAdvice.class);
assertThat(controllerAdvice).isNotNull();
assertThat(controllerAdvice.value()).isEqualTo("com.example");
assertThat(controllerAdvice.basePackage()).isEqualTo("com.example");
}
@Test
public void isSingletonAndIsPrototypeWithStaticFactory() {
StaticListableBeanFactory lbf = new StaticListableBeanFactory();
@ -393,6 +423,35 @@ public class BeanFactoryUtilsTests { @@ -393,6 +423,35 @@ public class BeanFactoryUtilsTests {
}
@Retention(RetentionPolicy.RUNTIME)
@interface ControllerAdvice {
@AliasFor("basePackage")
String value() default "";
@AliasFor("value")
String basePackage() default "";
}
@Retention(RetentionPolicy.RUNTIME)
@ControllerAdvice
@interface RestControllerAdvice {
@AliasFor(annotation = ControllerAdvice.class)
String value() default "";
@AliasFor(annotation = ControllerAdvice.class)
String basePackage() default "";
}
@ControllerAdvice("com.example")
static class ControllerAdviceClass {
}
@RestControllerAdvice("com.example")
static class RestControllerAdviceClass {
}
static class TestBeanSmartFactoryBean implements SmartFactoryBean<TestBean> {
private final TestBean testBean = new TestBean("enigma", 42);

156
spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ExceptionHandlerTests.java

@ -16,16 +16,23 @@ @@ -16,16 +16,23 @@
package org.springframework.test.web.servlet.samples.standalone;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.RestControllerAdvice;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.forwardedUrl;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
import static org.springframework.test.web.servlet.setup.MockMvcBuilders.standaloneSetup;
@ -37,28 +44,32 @@ import static org.springframework.test.web.servlet.setup.MockMvcBuilders.standal @@ -37,28 +44,32 @@ import static org.springframework.test.web.servlet.setup.MockMvcBuilders.standal
*/
class ExceptionHandlerTests {
@Test
void localExceptionHandlerMethod() throws Exception {
standaloneSetup(new PersonController()).build()
.perform(get("/person/Clyde"))
@Nested
class MvcTests {
@Test
void localExceptionHandlerMethod() throws Exception {
standaloneSetup(new PersonController()).build()
.perform(get("/person/Clyde"))
.andExpect(status().isOk())
.andExpect(forwardedUrl("errorView"));
}
}
@Test
void globalExceptionHandlerMethod() throws Exception {
standaloneSetup(new PersonController()).setControllerAdvice(new GlobalExceptionHandler()).build()
@Test
void globalExceptionHandlerMethod() throws Exception {
standaloneSetup(new PersonController()).setControllerAdvice(new GlobalExceptionHandler()).build()
.perform(get("/person/Bonnie"))
.andExpect(status().isOk())
.andExpect(forwardedUrl("globalErrorView"));
}
}
@Test
void globalExceptionHandlerMethodUsingClassArgument() throws Exception {
standaloneSetup(PersonController.class).setControllerAdvice(GlobalExceptionHandler.class).build()
@Test
void globalExceptionHandlerMethodUsingClassArgument() throws Exception {
standaloneSetup(PersonController.class).setControllerAdvice(GlobalExceptionHandler.class).build()
.perform(get("/person/Bonnie"))
.andExpect(status().isOk())
.andExpect(forwardedUrl("globalErrorView"));
}
}
@ -82,7 +93,6 @@ class ExceptionHandlerTests { @@ -82,7 +93,6 @@ class ExceptionHandlerTests {
}
}
@ControllerAdvice
private static class GlobalExceptionHandler {
@ -92,4 +102,124 @@ class ExceptionHandlerTests { @@ -92,4 +102,124 @@ class ExceptionHandlerTests {
}
}
@Nested
class RestTests {
@Test
void noException() throws Exception {
standaloneSetup(RestPersonController.class)
.setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build()
.perform(get("/person/Yoda").accept(MediaType.APPLICATION_JSON))
.andExpect(status().isOk())
.andExpect(jsonPath("$.name").value("Yoda"));
}
@Test
void localExceptionHandlerMethod() throws Exception {
standaloneSetup(RestPersonController.class)
.setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build()
.perform(get("/person/Luke").accept(MediaType.APPLICATION_JSON))
.andExpect(status().isOk())
.andExpect(jsonPath("$.error").value("local - IllegalArgumentException"));
}
@Test
void globalExceptionHandlerMethod() throws Exception {
standaloneSetup(RestPersonController.class)
.setControllerAdvice(RestGlobalExceptionHandler.class).build()
.perform(get("/person/Leia").accept(MediaType.APPLICATION_JSON))
.andExpect(status().isOk())
.andExpect(jsonPath("$.error").value("global - IllegalStateException"));
}
@Test
void globalRestPersonControllerExceptionHandlerTakesPrecedenceOverGlobalExceptionHandler() throws Exception {
standaloneSetup(RestPersonController.class)
.setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build()
.perform(get("/person/Leia").accept(MediaType.APPLICATION_JSON))
.andExpect(status().isOk())
.andExpect(jsonPath("$.error").value("globalPersonController - IllegalStateException"));
}
@Test // gh-25520
void noHandlerFound() throws Exception {
standaloneSetup(RestPersonController.class)
.setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class)
.addDispatcherServletCustomizer(dispatcherServlet -> dispatcherServlet.setThrowExceptionIfNoHandlerFound(true))
.build()
.perform(get("/bogus").accept(MediaType.APPLICATION_JSON))
.andExpect(status().isOk())
.andExpect(jsonPath("$.error").value("global - NoHandlerFoundException"));
}
}
@RestController
private static class RestPersonController {
@GetMapping("/person/{name}")
Person get(@PathVariable String name) {
switch (name) {
case "Luke":
throw new IllegalArgumentException();
case "Leia":
throw new IllegalStateException();
default:
return new Person("Yoda");
}
}
@ExceptionHandler
Error handleException(IllegalArgumentException exception) {
return new Error("local - " + exception.getClass().getSimpleName());
}
}
@RestControllerAdvice(assignableTypes = RestPersonController.class)
@Order(Ordered.HIGHEST_PRECEDENCE)
private static class RestPersonControllerExceptionHandler {
@ExceptionHandler
Error handleException(Throwable exception) {
return new Error("globalPersonController - " + exception.getClass().getSimpleName());
}
}
@RestControllerAdvice
@Order(Ordered.LOWEST_PRECEDENCE)
private static class RestGlobalExceptionHandler {
@ExceptionHandler
Error handleException(Throwable exception) {
return new Error( "global - " + exception.getClass().getSimpleName());
}
}
static class Person {
private final String name;
Person(String name) {
this.name = name;
}
public String getName() {
return name;
}
}
static class Error {
private final String error;
Error(String error) {
this.error = error;
}
public String getError() {
return error;
}
}
}

Loading…
Cancel
Save