Browse Source

WebMvc respects RouterFunction beans ordering

Closes gh-28595
pull/28653/head
rstoyanchev 2 years ago
parent
commit
52d0681ca1
  1. 7
      spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/RouterFunctionMapping.java
  2. 20
      spring-webflux/src/test/java/org/springframework/web/reactive/function/server/support/RouterFunctionMappingTests.java
  3. 46
      spring-webmvc/src/main/java/org/springframework/web/servlet/function/support/RouterFunctionMapping.java
  4. 66
      spring-webmvc/src/test/java/org/springframework/web/servlet/function/support/RouterFunctionMappingTests.java

7
spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/RouterFunctionMapping.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -119,12 +119,11 @@ public class RouterFunctionMapping extends AbstractHandlerMapping implements Ini @@ -119,12 +119,11 @@ public class RouterFunctionMapping extends AbstractHandlerMapping implements Ini
}
private List<RouterFunction<?>> routerFunctions() {
List<RouterFunction<?>> functions = obtainApplicationContext()
return obtainApplicationContext()
.getBeanProvider(RouterFunction.class)
.orderedStream()
.map(router -> (RouterFunction<?>)router)
.map(router -> (RouterFunction<?>) router)
.collect(Collectors.toList());
return (!CollectionUtils.isEmpty(functions) ? functions : Collections.emptyList());
}
private void logRouterFunctions(List<RouterFunction<?>> routerFunctions) {

20
spring-webflux/src/test/java/org/springframework/web/reactive/function/server/support/RouterFunctionMappingTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@ import org.junit.jupiter.api.Test; @@ -20,6 +20,7 @@ import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.reactive.function.server.HandlerFunction;
@ -72,6 +73,23 @@ public class RouterFunctionMappingTests { @@ -72,6 +73,23 @@ public class RouterFunctionMappingTests {
.verify();
}
@Test
void empty() throws Exception {
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
context.refresh();
RouterFunctionMapping mapping = new RouterFunctionMapping();
mapping.setMessageReaders(this.codecConfigurer.getReaders());
mapping.setApplicationContext(context);
mapping.afterPropertiesSet();
Mono<Object> result = mapping.getHandler(createExchange("https://example.com/match"));
StepVerifier.create(result)
.expectComplete()
.verify();
}
@Test
void changeParser() throws Exception {
HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.ok().build();

46
spring-webmvc/src/main/java/org/springframework/web/servlet/function/support/RouterFunctionMapping.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,11 +19,10 @@ package org.springframework.web.servlet.function.support; @@ -19,11 +19,10 @@ package org.springframework.web.servlet.function.support;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import javax.servlet.http.HttpServletRequest;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.ApplicationContext;
import org.springframework.core.SpringProperties;
@ -135,7 +134,7 @@ public class RouterFunctionMapping extends AbstractHandlerMapping implements Ini @@ -135,7 +134,7 @@ public class RouterFunctionMapping extends AbstractHandlerMapping implements Ini
@Override
public void afterPropertiesSet() throws Exception {
if (this.routerFunction == null) {
initRouterFunction();
initRouterFunctions();
}
if (CollectionUtils.isEmpty(this.messageConverters)) {
initMessageConverters();
@ -154,20 +153,39 @@ public class RouterFunctionMapping extends AbstractHandlerMapping implements Ini @@ -154,20 +153,39 @@ public class RouterFunctionMapping extends AbstractHandlerMapping implements Ini
* Detect a all {@linkplain RouterFunction router functions} in the
* current application context.
*/
@SuppressWarnings({"rawtypes", "unchecked"})
private void initRouterFunction() {
ApplicationContext applicationContext = obtainApplicationContext();
Map<String, RouterFunction> beans =
(this.detectHandlerFunctionsInAncestorContexts ?
BeanFactoryUtils.beansOfTypeIncludingAncestors(applicationContext, RouterFunction.class) :
applicationContext.getBeansOfType(RouterFunction.class));
List<RouterFunction> routerFunctions = new ArrayList<>(beans.values());
private void initRouterFunctions() {
List<RouterFunction<?>> routerFunctions = routerFunctions();
this.routerFunction = routerFunctions.stream().reduce(RouterFunction::andOther).orElse(null);
logRouterFunctions(routerFunctions);
}
@SuppressWarnings("rawtypes")
private void logRouterFunctions(List<RouterFunction> routerFunctions) {
private List<RouterFunction<?>> routerFunctions() {
List<RouterFunction<?>> routerFunctions = new ArrayList<>();
if (this.detectHandlerFunctionsInAncestorContexts) {
detectRouterFunctionsInAncestorContexts(obtainApplicationContext(), routerFunctions);
}
obtainApplicationContext()
.getBeanProvider(RouterFunction.class)
.orderedStream()
.map(router -> (RouterFunction<?>) router)
.collect(Collectors.toCollection(() -> routerFunctions));
return routerFunctions;
}
private void detectRouterFunctionsInAncestorContexts(
ApplicationContext applicationContext, List<RouterFunction<?>> routerFunctions) {
ApplicationContext parentContext = applicationContext.getParent();
if (parentContext != null) {
detectRouterFunctionsInAncestorContexts(parentContext, routerFunctions);
parentContext.getBeanProvider(RouterFunction.class)
.orderedStream()
.map(router -> (RouterFunction<?>) router)
.collect(Collectors.toCollection(() -> routerFunctions));
}
}
private void logRouterFunctions(List<RouterFunction<?>> routerFunctions) {
if (mappingsLogger.isDebugEnabled()) {
routerFunctions.forEach(function -> mappingsLogger.debug("Mapped " + function));
}

66
spring-webmvc/src/test/java/org/springframework/web/servlet/function/support/RouterFunctionMappingTests.java

@ -1,5 +1,5 @@ @@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,7 +21,10 @@ import java.util.List; @@ -21,7 +21,10 @@ import java.util.List;
import java.util.Optional;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerMapping;
@ -41,7 +44,7 @@ import static org.assertj.core.api.Assertions.assertThat; @@ -41,7 +44,7 @@ import static org.assertj.core.api.Assertions.assertThat;
*/
class RouterFunctionMappingTests {
private List<HttpMessageConverter<?>> messageConverters = Collections.emptyList();
private final List<HttpMessageConverter<?>> messageConverters = Collections.emptyList();
@Test
void normal() throws Exception {
@ -71,6 +74,65 @@ class RouterFunctionMappingTests { @@ -71,6 +74,65 @@ class RouterFunctionMappingTests {
assertThat(result).isNull();
}
@Test
void empty() throws Exception {
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
context.refresh();
RouterFunctionMapping mapping = new RouterFunctionMapping();
mapping.setMessageConverters(this.messageConverters);
mapping.setApplicationContext(context);
mapping.afterPropertiesSet();
MockHttpServletRequest request = createTestRequest("/match");
HandlerExecutionChain result = mapping.getHandler(request);
assertThat(result).isNull();
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void detectHandlerFunctionsInAncestorContexts(boolean detect) throws Exception {
HandlerFunction<ServerResponse> function1 = request -> ServerResponse.ok().build();
HandlerFunction<ServerResponse> function2 = request -> ServerResponse.ok().build();
HandlerFunction<ServerResponse> function3 = request -> ServerResponse.ok().build();
AnnotationConfigApplicationContext context1 = new AnnotationConfigApplicationContext();
context1.registerBean(RouterFunction.class, () -> RouterFunctions.route().GET("/fn1", function1).build());
context1.refresh();
AnnotationConfigApplicationContext context2 = new AnnotationConfigApplicationContext();
context2.registerBean(RouterFunction.class, () -> RouterFunctions.route().GET("/fn2", function2).build());
context2.setParent(context1);
context2.refresh();
AnnotationConfigApplicationContext context3 = new AnnotationConfigApplicationContext();
context3.registerBean(RouterFunction.class, () -> RouterFunctions.route().GET("/fn3", function3).build());
context3.setParent(context2);
context3.refresh();
RouterFunctionMapping mapping = new RouterFunctionMapping();
mapping.setDetectHandlerFunctionsInAncestorContexts(detect);
mapping.setMessageConverters(this.messageConverters);
mapping.setApplicationContext(context3);
mapping.afterPropertiesSet();
HandlerExecutionChain chain1 = mapping.getHandler(createTestRequest("/fn1"));
HandlerExecutionChain chain2 = mapping.getHandler(createTestRequest("/fn2"));
if (detect) {
assertThat(chain1).isNotNull().extracting(HandlerExecutionChain::getHandler).isSameAs(function1);
assertThat(chain2).isNotNull().extracting(HandlerExecutionChain::getHandler).isSameAs(function2);
}
else {
assertThat(chain1).isNull();
assertThat(chain2).isNull();
}
HandlerExecutionChain chain3 = mapping.getHandler(createTestRequest("/fn3"));
assertThat(chain3).isNotNull().extracting(HandlerExecutionChain::getHandler).isSameAs(function3);
}
@Test
void changeParser() throws Exception {
HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.ok().build();

Loading…
Cancel
Save