From 00a5106bfa664d9c310bdd5872927fec4faaf96c Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Tue, 7 May 2019 16:47:05 +0200 Subject: [PATCH] Add route(RequestPredicate, HandlerFunction) to RouterFunctions builder Closes gh-22701 --- .../servlet/function/RouterFunctionBuilder.java | 6 ++++++ .../web/servlet/function/RouterFunctions.java | 11 +++++++++++ .../function/RouterFunctionBuilderTests.java | 17 ++++++++++++++--- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RouterFunctionBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RouterFunctionBuilder.java index 75889036d6..5eb17d9191 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RouterFunctionBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RouterFunctionBuilder.java @@ -132,6 +132,12 @@ class RouterFunctionBuilder implements RouterFunctions.Builder { return add(RequestPredicates.OPTIONS(pattern), handlerFunction); } + @Override + public RouterFunctions.Builder route(RequestPredicate predicate, + HandlerFunction handlerFunction) { + return add(RouterFunctions.route(predicate, handlerFunction)); + } + @Override public RouterFunctions.Builder OPTIONS(String pattern, RequestPredicate predicate, HandlerFunction handlerFunction) { diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RouterFunctions.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RouterFunctions.java index 455773e516..5fa6a64774 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RouterFunctions.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/RouterFunctions.java @@ -349,6 +349,17 @@ public abstract class RouterFunctions { */ Builder OPTIONS(String pattern, HandlerFunction handlerFunction); + /** + * Adds a route to the given handler function that handles all requests that match the + * given predicate. + * + * @param predicate the request predicate to match + * @param handlerFunction the handler function to handle all requests that match the predicate + * @return this builder + * @see RequestPredicates + */ + Builder route(RequestPredicate predicate, HandlerFunction handlerFunction); + /** * Adds a route to the given handler function that handles all HTTP {@code OPTIONS} requests * that match the given pattern and predicate. diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionBuilderTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionBuilderTests.java index 44df89e3be..8fc95397a0 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionBuilderTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/RouterFunctionBuilderTests.java @@ -29,6 +29,7 @@ import org.springframework.mock.web.test.MockHttpServletRequest; import static java.util.Collections.emptyList; import static org.junit.Assert.*; +import static org.springframework.web.servlet.function.RequestPredicates.HEAD; /** * @author Arjen Poutsma @@ -41,17 +42,27 @@ public class RouterFunctionBuilderTests { .GET("/foo", request -> ServerResponse.ok().build()) .POST("/", RequestPredicates.contentType(MediaType.TEXT_PLAIN), request -> ServerResponse.noContent().build()) + .route(HEAD("/foo"), request -> ServerResponse.accepted().build()) .build(); MockHttpServletRequest servletRequest = new MockHttpServletRequest("GET", "/foo"); - ServerRequest fooRequest = new DefaultServerRequest(servletRequest, emptyList()); + ServerRequest getFooRequest = new DefaultServerRequest(servletRequest, emptyList()); - Optional responseStatus = route.route(fooRequest) - .map(handlerFunction -> handle(handlerFunction, fooRequest)) + Optional responseStatus = route.route(getFooRequest) + .map(handlerFunction -> handle(handlerFunction, getFooRequest)) .map(ServerResponse::statusCode) .map(HttpStatus::value); assertEquals(200, responseStatus.get().intValue()); + servletRequest = new MockHttpServletRequest("HEAD", "/foo"); + ServerRequest headFooRequest = new DefaultServerRequest(servletRequest, emptyList()); + + responseStatus = route.route(headFooRequest) + .map(handlerFunction -> handle(handlerFunction, getFooRequest)) + .map(ServerResponse::statusCode) + .map(HttpStatus::value); + assertEquals(202, responseStatus.get().intValue()); + servletRequest = new MockHttpServletRequest("POST", "/"); servletRequest.setContentType("text/plain"); ServerRequest barRequest = new DefaultServerRequest(servletRequest, emptyList());