diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java index 613595b2c0..6d6339b7b2 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java @@ -36,6 +36,7 @@ import org.springframework.util.PathMatcher; import org.springframework.web.reactive.function.BodyExtractor; import org.springframework.web.server.WebSession; import org.springframework.web.util.ParsingPathMatcher; +import org.springframework.web.util.UriUtils; /** * Implementations of {@link RequestPredicate} that implement various useful request matching operations, such as @@ -85,7 +86,10 @@ public abstract class RequestPredicates { * @return a predicate that tests against the given path pattern */ public static RequestPredicate path(String pattern, PathMatcher pathMatcher) { - return new PathPredicate(pattern, pathMatcher); + Assert.notNull(pattern, "'pattern' must not be null"); + Assert.notNull(pathMatcher, "'pathMatcher' must not be null"); + + return new PathMatchingPredicate(pattern, pathMatcher); } /** @@ -95,7 +99,7 @@ public abstract class RequestPredicates { * @return a predicate that tests against the given header predicate */ public static RequestPredicate headers(Predicate headersPredicate) { - return new HeaderPredicates(headersPredicate); + return new HeadersPredicate(headersPredicate); } /** @@ -221,6 +225,90 @@ public abstract class RequestPredicates { return method(HttpMethod.OPTIONS).and(path(pattern)); } + /** + * Return a {@code RequestPredicate} that matches if the request's path has the given extension. + * @param extension the path extension to match against + * @return a predicate that matches if the request's path has the given file extension + */ + public static RequestPredicate pathExtension(String extension) { + Assert.notNull(extension, "'extension' must not be null"); + return pathExtension(extension::equalsIgnoreCase); + } + + /** + * Return a {@code RequestPredicate} that matches if the request's path matches the given + * predicate. + * @param extensionPredicate the predicate to test against the request path extension + * @return a predicate that matches if the given predicate matches against the request's path + * file extension + */ + public static RequestPredicate pathExtension(Predicate extensionPredicate) { + Assert.notNull(extensionPredicate, "'extensionPredicate' must not be null"); + return request -> { + String pathExtension = UriUtils.extractFileExtension(request.path()); + return extensionPredicate.test(pathExtension); + }; + } + + /** + * Return a {@code RequestPredicate} that tests the request's query parameter of the given name + * against the given predicate. + * @param name the name of the query parameter to test against + * @param predicate predicate to test against the query parameter value + * @return a predicate that matches the given predicate against the query parameter of the given + * name + * @see ServerRequest#queryParam(String) + */ + public static RequestPredicate queryParam(String name, Predicate predicate) { + return request -> { + Optional s = request.queryParam(name); + return s.filter(predicate).isPresent(); + }; + } + + /** + * Return a {@code RequestPredicate} that matches JSON requests. The returned predicate + * matches if the request has {@code application/json} in the {@code Accept} header, or if the + * request path has a {@code .json} file extension. + * + * @return a predicate that matches JSON + * @see #accept(MediaType...) + * @see #pathExtension(String) + */ + public static RequestPredicate json() { + return accept(MediaType.APPLICATION_JSON) + .or(pathExtension("json")); + } + + /** + * Return a {@code RequestPredicate} that matches HTML requests. The returned predicate + * matches if the request has {@code text/html} in the {@code Accept} header, or if the request + * path has a {@code .html} file extension. + * + * @return a predicate that matches HTML requests + * @see #accept(MediaType...) + * @see #pathExtension(String) + */ + public static RequestPredicate html() { + return accept(MediaType.TEXT_HTML) + .or(pathExtension("html")); + } + + /** + * Return a {@code RequestPredicate} that matches XML requests. The returned predicate + * matches if the request has {@code text/xml} or {@code application/xml} in the {@code Accept} + * header, or if the request path has a {@code .xml} file extension. + * + * @return a predicate that matches XML requests + * @see #accept(MediaType...) + * @see #pathExtension(String) + */ + public static RequestPredicate xml() { + return accept(MediaType.TEXT_XML) + .or(accept(MediaType.APPLICATION_XML)) + .or(pathExtension("xml")); + } + private static class HttpMethodPredicate implements RequestPredicate { private final HttpMethod httpMethod; @@ -236,13 +324,13 @@ public abstract class RequestPredicates { } } - private static class PathPredicate implements RequestPredicate { + private static class PathMatchingPredicate implements RequestPredicate { private final String pattern; private final PathMatcher pathMatcher; - public PathPredicate(String pattern, PathMatcher pathMatcher) { + public PathMatchingPredicate(String pattern, PathMatcher pathMatcher) { Assert.notNull(pattern, "'pattern' must not be null"); Assert.notNull(pathMatcher, "'pathMatcher' must not be null"); this.pattern = pattern; @@ -273,11 +361,11 @@ public abstract class RequestPredicates { } } - private static class HeaderPredicates implements RequestPredicate { + private static class HeadersPredicate implements RequestPredicate { private final Predicate headersPredicate; - public HeaderPredicates(Predicate headersPredicate) { + public HeadersPredicate(Predicate headersPredicate) { Assert.notNull(headersPredicate, "'headersPredicate' must not be null"); this.headersPredicate = headersPredicate; } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java index 7fa2bfdc8b..eb149e6a9b 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -130,4 +130,96 @@ public class RequestPredicatesTests { assertFalse(predicate.test(request)); } + @Test + public void pathExtension() throws Exception { + URI uri = URI.create("http://localhost/file.txt"); + RequestPredicate predicate = RequestPredicates.pathExtension("txt"); + MockServerRequest request = MockServerRequest.builder().uri(uri).build(); + assertTrue(predicate.test(request)); + + predicate = RequestPredicates.pathExtension("bar"); + assertFalse(predicate.test(request)); + + uri = URI.create("http://localhost/file.foo"); + request = MockServerRequest.builder().uri(uri).build(); + assertFalse(predicate.test(request)); + + } + + @Test + public void queryParam() throws Exception { + MockServerRequest request = MockServerRequest.builder().queryParam("foo", "bar").build(); + RequestPredicate predicate = RequestPredicates.queryParam("foo", s -> s.equals("bar")); + assertTrue(predicate.test(request)); + + predicate = RequestPredicates.queryParam("foo", s -> s.equals("baz")); + assertFalse(predicate.test(request)); + } + + @Test + public void json() throws Exception { + RequestPredicate predicate = RequestPredicates.json(); + MockServerRequest request = MockServerRequest.builder().header("Accept", MediaType.APPLICATION_JSON.toString()).build(); + assertTrue(predicate.test(request)); + + request = MockServerRequest.builder().header("Accept", MediaType.TEXT_HTML.toString()).build(); + assertFalse(predicate.test(request)); + + URI uri = URI.create("http://localhost/file.json"); + request = MockServerRequest.builder().uri(uri).build(); + assertTrue(predicate.test(request)); + + uri = URI.create("http://localhost/file.html"); + request = MockServerRequest.builder().uri(uri).build(); + assertFalse(predicate.test(request)); + + request = MockServerRequest.builder().build(); + assertFalse(predicate.test(request)); + } + + @Test + public void html() throws Exception { + RequestPredicate predicate = RequestPredicates.html(); + MockServerRequest request = MockServerRequest.builder().header("Accept", MediaType.TEXT_HTML.toString()).build(); + assertTrue(predicate.test(request)); + + request = MockServerRequest.builder().header("Accept", MediaType.APPLICATION_JSON.toString()).build(); + assertFalse(predicate.test(request)); + + URI uri = URI.create("http://localhost/file.html"); + request = MockServerRequest.builder().uri(uri).build(); + assertTrue(predicate.test(request)); + + uri = URI.create("http://localhost/file.json"); + request = MockServerRequest.builder().uri(uri).build(); + assertFalse(predicate.test(request)); + + request = MockServerRequest.builder().build(); + assertFalse(predicate.test(request)); + } + + @Test + public void xml() throws Exception { + RequestPredicate predicate = RequestPredicates.xml(); + MockServerRequest request = MockServerRequest.builder().header("Accept", MediaType.TEXT_XML.toString()).build(); + assertTrue(predicate.test(request)); + + request = MockServerRequest.builder().header("Accept", MediaType.APPLICATION_XML.toString()).build(); + assertTrue(predicate.test(request)); + + request = MockServerRequest.builder().header("Accept", MediaType.TEXT_HTML.toString()).build(); + assertFalse(predicate.test(request)); + + URI uri = URI.create("http://localhost/file.xml"); + request = MockServerRequest.builder().uri(uri).build(); + assertTrue(predicate.test(request)); + + uri = URI.create("http://localhost/file.json"); + request = MockServerRequest.builder().uri(uri).build(); + assertFalse(predicate.test(request)); + + request = MockServerRequest.builder().build(); + assertFalse(predicate.test(request)); + } + } \ No newline at end of file