diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctionBuilder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctionBuilder.java index 4572ca03ae..8fd0cfb616 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctionBuilder.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctionBuilder.java @@ -23,6 +23,7 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.stream.Stream; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -43,6 +44,8 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { private final List> filterFunctions = new ArrayList<>(); + private final List> errorHandlers = new ArrayList<>(); + @Override public RouterFunctions.Builder add(RouterFunction routerFunction) { @@ -310,8 +313,9 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { Assert.notNull(predicate, "Predicate must not be null"); Assert.notNull(responseProvider, "ResponseProvider must not be null"); - return filter((request, next) -> next.handle(request) + this.errorHandlers.add(0, (request, next) -> next.handle(request) .onErrorResume(predicate, t -> responseProvider.apply(t, request))); + return this; } @Override @@ -321,8 +325,9 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { Assert.notNull(exceptionType, "ExceptionType must not be null"); Assert.notNull(responseProvider, "ResponseProvider must not be null"); - return filter((request, next) -> next.handle(request) + this.errorHandlers.add(0, (request, next) -> next.handle(request) .onErrorResume(exceptionType, t -> responseProvider.apply(t, request))); + return this; } @Override @@ -332,12 +337,12 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { } RouterFunction result = new BuiltRouterFunction(this.routerFunctions); - if (this.filterFunctions.isEmpty()) { + if (this.filterFunctions.isEmpty() && this.errorHandlers.isEmpty()) { return result; } else { HandlerFilterFunction filter = - this.filterFunctions.stream() + Stream.concat(this.filterFunctions.stream(), this.errorHandlers.stream()) .reduce(HandlerFilterFunction::andThen) .orElseThrow(IllegalStateException::new); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionBuilderTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionBuilderTests.java index 7a5f5292fc..25d40bbd9f 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionBuilderTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionBuilderTests.java @@ -16,6 +16,7 @@ package org.springframework.web.reactive.function.server; +import java.io.IOException; import java.util.Collections; import java.util.concurrent.atomic.AtomicInteger; @@ -210,4 +211,26 @@ public class RouterFunctionBuilderTests { .verifyComplete(); } + @Test + public void multipleOnErrors() { + RouterFunction route = RouterFunctions.route() + .GET("/error", request -> Mono.error(new IOException())) + .onError(IOException.class, (t, r) -> ServerResponse.status(200).build()) + .onError(Exception.class, (t, r) -> ServerResponse.status(201).build()) + .build(); + + MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com/error").build(); + ServerRequest serverRequest = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList()); + + Mono responseStatus = route.route(serverRequest) + .flatMap(handlerFunction -> handlerFunction.handle(serverRequest)) + .map(ServerResponse::statusCode); + + StepVerifier.create(responseStatus) + .assertNext(status -> assertThat(status).isEqualTo(HttpStatus.OK)) + .verifyComplete(); + + } + + } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RouterFunctionBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RouterFunctionBuilder.java index 49af7797c8..77d1813cd6 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RouterFunctionBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RouterFunctionBuilder.java @@ -24,6 +24,7 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.stream.Stream; import org.springframework.core.io.Resource; import org.springframework.http.HttpMethod; @@ -41,6 +42,8 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { private final List> filterFunctions = new ArrayList<>(); + private final List> errorHandlers = new ArrayList<>(); + @Override public RouterFunctions.Builder add(RouterFunction routerFunction) { @@ -307,7 +310,8 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { Assert.notNull(predicate, "Predicate must not be null"); Assert.notNull(responseProvider, "ResponseProvider must not be null"); - return filter(HandlerFilterFunction.ofErrorHandler(predicate, responseProvider)); + this.errorHandlers.add(0, HandlerFilterFunction.ofErrorHandler(predicate, responseProvider)); + return this; } @Override @@ -316,8 +320,7 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { Assert.notNull(exceptionType, "ExceptionType must not be null"); Assert.notNull(responseProvider, "ResponseProvider must not be null"); - return filter(HandlerFilterFunction.ofErrorHandler(exceptionType::isInstance, - responseProvider)); + return onError(exceptionType::isInstance, responseProvider); } @Override @@ -327,12 +330,12 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { } RouterFunction result = new BuiltRouterFunction(this.routerFunctions); - if (this.filterFunctions.isEmpty()) { + if (this.filterFunctions.isEmpty() && this.errorHandlers.isEmpty()) { return result; } else { HandlerFilterFunction filter = - this.filterFunctions.stream() + Stream.concat(this.filterFunctions.stream(), this.errorHandlers.stream()) .reduce(HandlerFilterFunction::andThen) .orElseThrow(IllegalStateException::new); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionBuilderTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionBuilderTests.java index 55d0e484a9..7c3a7b856f 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionBuilderTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionBuilderTests.java @@ -16,6 +16,7 @@ package org.springframework.web.servlet.function; +import java.io.IOException; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; @@ -52,36 +53,32 @@ class RouterFunctionBuilderTests { ServerRequest getFooRequest = initRequest("GET", "/foo"); - Optional responseStatus = route.route(getFooRequest) + Optional responseStatus = route.route(getFooRequest) .map(handlerFunction -> handle(handlerFunction, getFooRequest)) - .map(ServerResponse::statusCode) - .map(HttpStatus::value); - assertThat(responseStatus.get().intValue()).isEqualTo(200); + .map(ServerResponse::statusCode); + assertThat(responseStatus).contains(HttpStatus.OK); ServerRequest headFooRequest = initRequest("HEAD", "/foo"); responseStatus = route.route(headFooRequest) .map(handlerFunction -> handle(handlerFunction, getFooRequest)) - .map(ServerResponse::statusCode) - .map(HttpStatus::value); - assertThat(responseStatus.get().intValue()).isEqualTo(202); + .map(ServerResponse::statusCode); + assertThat(responseStatus).contains(HttpStatus.ACCEPTED); ServerRequest barRequest = initRequest("POST", "/", req -> req.setContentType("text/plain")); responseStatus = route.route(barRequest) .map(handlerFunction -> handle(handlerFunction, barRequest)) - .map(ServerResponse::statusCode) - .map(HttpStatus::value); - assertThat(responseStatus.get().intValue()).isEqualTo(204); + .map(ServerResponse::statusCode); + assertThat(responseStatus).contains(HttpStatus.NO_CONTENT); ServerRequest invalidRequest = initRequest("POST", "/"); responseStatus = route.route(invalidRequest) .map(handlerFunction -> handle(handlerFunction, invalidRequest)) - .map(ServerResponse::statusCode) - .map(HttpStatus::value); + .map(ServerResponse::statusCode); - assertThat(responseStatus.isPresent()).isFalse(); + assertThat(responseStatus).isEmpty(); } private static ServerResponse handle(HandlerFunction handlerFunction, @@ -105,19 +102,17 @@ class RouterFunctionBuilderTests { ServerRequest resourceRequest = initRequest("GET", "/resources/response.txt"); - Optional responseStatus = route.route(resourceRequest) + Optional responseStatus = route.route(resourceRequest) .map(handlerFunction -> handle(handlerFunction, resourceRequest)) - .map(ServerResponse::statusCode) - .map(HttpStatus::value); - assertThat(responseStatus.get().intValue()).isEqualTo(200); + .map(ServerResponse::statusCode); + assertThat(responseStatus).contains(HttpStatus.OK); ServerRequest invalidRequest = initRequest("POST", "/resources/foo.txt"); responseStatus = route.route(invalidRequest) .map(handlerFunction -> handle(handlerFunction, invalidRequest)) - .map(ServerResponse::statusCode) - .map(HttpStatus::value); - assertThat(responseStatus.isPresent()).isFalse(); + .map(ServerResponse::statusCode); + assertThat(responseStatus).isEmpty(); } @Test @@ -132,11 +127,10 @@ class RouterFunctionBuilderTests { ServerRequest fooRequest = initRequest("GET", "/foo/bar/baz"); - Optional responseStatus = route.route(fooRequest) + Optional responseStatus = route.route(fooRequest) .map(handlerFunction -> handle(handlerFunction, fooRequest)) - .map(ServerResponse::statusCode) - .map(HttpStatus::value); - assertThat(responseStatus.get().intValue()).isEqualTo(200); + .map(ServerResponse::statusCode); + assertThat(responseStatus).contains(HttpStatus.OK); } @Test @@ -181,13 +175,33 @@ class RouterFunctionBuilderTests { ServerRequest barRequest = initRequest("GET", "/bar"); - Optional responseStatus = route.route(barRequest) + Optional responseStatus = route.route(barRequest) .map(handlerFunction -> handle(handlerFunction, barRequest)) - .map(ServerResponse::statusCode) - .map(HttpStatus::value); - assertThat(responseStatus.get().intValue()).isEqualTo(500); + .map(ServerResponse::statusCode); + assertThat(responseStatus).contains(HttpStatus.INTERNAL_SERVER_ERROR); } + @Test + public void multipleOnErrors() { + RouterFunction route = RouterFunctions.route() + .GET("/error", request -> { + throw new IOException(); + }) + .onError(IOException.class, (t, r) -> ServerResponse.status(200).build()) + .onError(Exception.class, (t, r) -> ServerResponse.status(201).build()) + .build(); + + MockHttpServletRequest servletRequest = new MockHttpServletRequest("GET", "/error"); + ServerRequest serverRequest = new DefaultServerRequest(servletRequest, emptyList()); + + Optional responseStatus = route.route(serverRequest) + .map(handlerFunction -> handle(handlerFunction, serverRequest)) + .map(ServerResponse::statusCode); + assertThat(responseStatus).contains(HttpStatus.OK); + + } + + private ServerRequest initRequest(String httpMethod, String requestUri) { return initRequest(httpMethod, requestUri, null);