Browse Source

Revert errorhandler order in RouterFunctionBuilder

Prior to this commit, error handlers in the WebMvc.fn and WebFlux.fn
router function builders had to be registered in an unintuitive, reverse
order, due to the filter chain composition model used.
This commit reverses the error handler order, so that more specific
error handlers can come before generic ones.

Closes gh-25541
pull/25876/head
Arjen Poutsma 5 years ago
parent
commit
392895e256
  1. 13
      spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctionBuilder.java
  2. 23
      spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionBuilderTests.java
  3. 13
      spring-webmvc/src/main/java/org/springframework/web/servlet/function/RouterFunctionBuilder.java
  4. 70
      spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionBuilderTests.java

13
spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctionBuilder.java

@ -23,6 +23,7 @@ import java.util.function.Consumer; @@ -23,6 +23,7 @@ import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Stream;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@ -43,6 +44,8 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { @@ -43,6 +44,8 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
private final List<HandlerFilterFunction<ServerResponse, ServerResponse>> filterFunctions = new ArrayList<>();
private final List<HandlerFilterFunction<ServerResponse, ServerResponse>> errorHandlers = new ArrayList<>();
@Override
public RouterFunctions.Builder add(RouterFunction<ServerResponse> routerFunction) {
@ -310,8 +313,9 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { @@ -310,8 +313,9 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
Assert.notNull(predicate, "Predicate must not be null");
Assert.notNull(responseProvider, "ResponseProvider must not be null");
return filter((request, next) -> next.handle(request)
this.errorHandlers.add(0, (request, next) -> next.handle(request)
.onErrorResume(predicate, t -> responseProvider.apply(t, request)));
return this;
}
@Override
@ -321,8 +325,9 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { @@ -321,8 +325,9 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
Assert.notNull(exceptionType, "ExceptionType must not be null");
Assert.notNull(responseProvider, "ResponseProvider must not be null");
return filter((request, next) -> next.handle(request)
this.errorHandlers.add(0, (request, next) -> next.handle(request)
.onErrorResume(exceptionType, t -> responseProvider.apply(t, request)));
return this;
}
@Override
@ -332,12 +337,12 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { @@ -332,12 +337,12 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
}
RouterFunction<ServerResponse> result = new BuiltRouterFunction(this.routerFunctions);
if (this.filterFunctions.isEmpty()) {
if (this.filterFunctions.isEmpty() && this.errorHandlers.isEmpty()) {
return result;
}
else {
HandlerFilterFunction<ServerResponse, ServerResponse> filter =
this.filterFunctions.stream()
Stream.concat(this.filterFunctions.stream(), this.errorHandlers.stream())
.reduce(HandlerFilterFunction::andThen)
.orElseThrow(IllegalStateException::new);

23
spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionBuilderTests.java

@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
package org.springframework.web.reactive.function.server;
import java.io.IOException;
import java.util.Collections;
import java.util.concurrent.atomic.AtomicInteger;
@ -210,4 +211,26 @@ public class RouterFunctionBuilderTests { @@ -210,4 +211,26 @@ public class RouterFunctionBuilderTests {
.verifyComplete();
}
@Test
public void multipleOnErrors() {
RouterFunction<ServerResponse> route = RouterFunctions.route()
.GET("/error", request -> Mono.error(new IOException()))
.onError(IOException.class, (t, r) -> ServerResponse.status(200).build())
.onError(Exception.class, (t, r) -> ServerResponse.status(201).build())
.build();
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com/error").build();
ServerRequest serverRequest = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
Mono<HttpStatus> responseStatus = route.route(serverRequest)
.flatMap(handlerFunction -> handlerFunction.handle(serverRequest))
.map(ServerResponse::statusCode);
StepVerifier.create(responseStatus)
.assertNext(status -> assertThat(status).isEqualTo(HttpStatus.OK))
.verifyComplete();
}
}

13
spring-webmvc/src/main/java/org/springframework/web/servlet/function/RouterFunctionBuilder.java

@ -24,6 +24,7 @@ import java.util.function.Consumer; @@ -24,6 +24,7 @@ import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpMethod;
@ -41,6 +42,8 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { @@ -41,6 +42,8 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
private final List<HandlerFilterFunction<ServerResponse, ServerResponse>> filterFunctions = new ArrayList<>();
private final List<HandlerFilterFunction<ServerResponse, ServerResponse>> errorHandlers = new ArrayList<>();
@Override
public RouterFunctions.Builder add(RouterFunction<ServerResponse> routerFunction) {
@ -307,7 +310,8 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { @@ -307,7 +310,8 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
Assert.notNull(predicate, "Predicate must not be null");
Assert.notNull(responseProvider, "ResponseProvider must not be null");
return filter(HandlerFilterFunction.ofErrorHandler(predicate, responseProvider));
this.errorHandlers.add(0, HandlerFilterFunction.ofErrorHandler(predicate, responseProvider));
return this;
}
@Override
@ -316,8 +320,7 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { @@ -316,8 +320,7 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
Assert.notNull(exceptionType, "ExceptionType must not be null");
Assert.notNull(responseProvider, "ResponseProvider must not be null");
return filter(HandlerFilterFunction.ofErrorHandler(exceptionType::isInstance,
responseProvider));
return onError(exceptionType::isInstance, responseProvider);
}
@Override
@ -327,12 +330,12 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { @@ -327,12 +330,12 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
}
RouterFunction<ServerResponse> result = new BuiltRouterFunction(this.routerFunctions);
if (this.filterFunctions.isEmpty()) {
if (this.filterFunctions.isEmpty() && this.errorHandlers.isEmpty()) {
return result;
}
else {
HandlerFilterFunction<ServerResponse, ServerResponse> filter =
this.filterFunctions.stream()
Stream.concat(this.filterFunctions.stream(), this.errorHandlers.stream())
.reduce(HandlerFilterFunction::andThen)
.orElseThrow(IllegalStateException::new);

70
spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionBuilderTests.java

@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
package org.springframework.web.servlet.function;
import java.io.IOException;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
@ -52,36 +53,32 @@ class RouterFunctionBuilderTests { @@ -52,36 +53,32 @@ class RouterFunctionBuilderTests {
ServerRequest getFooRequest = initRequest("GET", "/foo");
Optional<Integer> responseStatus = route.route(getFooRequest)
Optional<HttpStatus> responseStatus = route.route(getFooRequest)
.map(handlerFunction -> handle(handlerFunction, getFooRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
assertThat(responseStatus.get().intValue()).isEqualTo(200);
.map(ServerResponse::statusCode);
assertThat(responseStatus).contains(HttpStatus.OK);
ServerRequest headFooRequest = initRequest("HEAD", "/foo");
responseStatus = route.route(headFooRequest)
.map(handlerFunction -> handle(handlerFunction, getFooRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
assertThat(responseStatus.get().intValue()).isEqualTo(202);
.map(ServerResponse::statusCode);
assertThat(responseStatus).contains(HttpStatus.ACCEPTED);
ServerRequest barRequest = initRequest("POST", "/", req -> req.setContentType("text/plain"));
responseStatus = route.route(barRequest)
.map(handlerFunction -> handle(handlerFunction, barRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
assertThat(responseStatus.get().intValue()).isEqualTo(204);
.map(ServerResponse::statusCode);
assertThat(responseStatus).contains(HttpStatus.NO_CONTENT);
ServerRequest invalidRequest = initRequest("POST", "/");
responseStatus = route.route(invalidRequest)
.map(handlerFunction -> handle(handlerFunction, invalidRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
.map(ServerResponse::statusCode);
assertThat(responseStatus.isPresent()).isFalse();
assertThat(responseStatus).isEmpty();
}
private static ServerResponse handle(HandlerFunction<ServerResponse> handlerFunction,
@ -105,19 +102,17 @@ class RouterFunctionBuilderTests { @@ -105,19 +102,17 @@ class RouterFunctionBuilderTests {
ServerRequest resourceRequest = initRequest("GET", "/resources/response.txt");
Optional<Integer> responseStatus = route.route(resourceRequest)
Optional<HttpStatus> responseStatus = route.route(resourceRequest)
.map(handlerFunction -> handle(handlerFunction, resourceRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
assertThat(responseStatus.get().intValue()).isEqualTo(200);
.map(ServerResponse::statusCode);
assertThat(responseStatus).contains(HttpStatus.OK);
ServerRequest invalidRequest = initRequest("POST", "/resources/foo.txt");
responseStatus = route.route(invalidRequest)
.map(handlerFunction -> handle(handlerFunction, invalidRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
assertThat(responseStatus.isPresent()).isFalse();
.map(ServerResponse::statusCode);
assertThat(responseStatus).isEmpty();
}
@Test
@ -132,11 +127,10 @@ class RouterFunctionBuilderTests { @@ -132,11 +127,10 @@ class RouterFunctionBuilderTests {
ServerRequest fooRequest = initRequest("GET", "/foo/bar/baz");
Optional<Integer> responseStatus = route.route(fooRequest)
Optional<HttpStatus> responseStatus = route.route(fooRequest)
.map(handlerFunction -> handle(handlerFunction, fooRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
assertThat(responseStatus.get().intValue()).isEqualTo(200);
.map(ServerResponse::statusCode);
assertThat(responseStatus).contains(HttpStatus.OK);
}
@Test
@ -181,13 +175,33 @@ class RouterFunctionBuilderTests { @@ -181,13 +175,33 @@ class RouterFunctionBuilderTests {
ServerRequest barRequest = initRequest("GET", "/bar");
Optional<Integer> responseStatus = route.route(barRequest)
Optional<HttpStatus> responseStatus = route.route(barRequest)
.map(handlerFunction -> handle(handlerFunction, barRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
assertThat(responseStatus.get().intValue()).isEqualTo(500);
.map(ServerResponse::statusCode);
assertThat(responseStatus).contains(HttpStatus.INTERNAL_SERVER_ERROR);
}
@Test
public void multipleOnErrors() {
RouterFunction<ServerResponse> route = RouterFunctions.route()
.GET("/error", request -> {
throw new IOException();
})
.onError(IOException.class, (t, r) -> ServerResponse.status(200).build())
.onError(Exception.class, (t, r) -> ServerResponse.status(201).build())
.build();
MockHttpServletRequest servletRequest = new MockHttpServletRequest("GET", "/error");
ServerRequest serverRequest = new DefaultServerRequest(servletRequest, emptyList());
Optional<HttpStatus> responseStatus = route.route(serverRequest)
.map(handlerFunction -> handle(handlerFunction, serverRequest))
.map(ServerResponse::statusCode);
assertThat(responseStatus).contains(HttpStatus.OK);
}
private ServerRequest initRequest(String httpMethod, String requestUri) {
return initRequest(httpMethod, requestUri, null);

Loading…
Cancel
Save