diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicate.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicate.java index 0f23918f74..36745e265c 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicate.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicate.java @@ -84,4 +84,15 @@ public interface RequestPredicate { return (test(request) ? Optional.of(request) : Optional.empty()); } + /** + * Accept the given visitor. Default implementation calls + * {@link RequestPredicates.Visitor#unknown(RequestPredicate)}; composed {@code RequestPredicate} + * implementations are expected to call {@code accept} for all components that make up this + * request predicate. + * @param visitor the visitor to accept + */ + default void accept(RequestPredicates.Visitor visitor) { + visitor.unknown(this); + } + } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java index d8396ea7c5..1e0ae421f7 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java @@ -41,12 +41,14 @@ import reactor.core.publisher.Mono; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.multipart.Part; import org.springframework.http.server.PathContainer; import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; @@ -146,24 +148,7 @@ public abstract class RequestPredicates { */ public static RequestPredicate contentType(MediaType... mediaTypes) { Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); - Set mediaTypeSet = new HashSet<>(Arrays.asList(mediaTypes)); - - return headers(new Predicate() { - @Override - public boolean test(ServerRequest.Headers headers) { - MediaType contentType = - headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); - boolean match = mediaTypeSet.stream() - .anyMatch(mediaType -> mediaType.includes(contentType)); - traceMatch("Content-Type", mediaTypeSet, contentType, match); - return match; - } - - @Override - public String toString() { - return String.format("Content-Type: %s", mediaTypeSet); - } - }); + return new ContentTypePredicate(mediaTypes); } /** @@ -175,29 +160,7 @@ public abstract class RequestPredicates { */ public static RequestPredicate accept(MediaType... mediaTypes) { Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty"); - Set mediaTypeSet = new HashSet<>(Arrays.asList(mediaTypes)); - - return headers(new Predicate() { - @Override - public boolean test(ServerRequest.Headers headers) { - List acceptedMediaTypes = headers.accept(); - if (acceptedMediaTypes.isEmpty()) { - acceptedMediaTypes = Collections.singletonList(MediaType.ALL); - } - else { - MediaType.sortBySpecificityAndQuality(acceptedMediaTypes); - } - boolean match = acceptedMediaTypes.stream() - .anyMatch(acceptedMediaType -> mediaTypeSet.stream() - .anyMatch(acceptedMediaType::isCompatibleWith)); - traceMatch("Accept", mediaTypeSet, acceptedMediaTypes, match); - return match; - } - @Override - public String toString() { - return String.format("Accept: %s", mediaTypeSet); - } - }); + return new AcceptPredicate(mediaTypes); } /** @@ -284,18 +247,7 @@ public abstract class RequestPredicates { */ public static RequestPredicate pathExtension(String extension) { Assert.notNull(extension, "'extension' must not be null"); - return pathExtension(new Predicate() { - @Override - public boolean test(String pathExtension) { - boolean match = extension.equalsIgnoreCase(pathExtension); - traceMatch("Extension", extension, pathExtension, match); - return match; - } - - public String toString() { - return String.format("*.%s", extension); - } - }); + return new PathExtensionPredicate(extension); } /** @@ -319,16 +271,7 @@ public abstract class RequestPredicates { * @see ServerRequest#queryParam(String) */ public static RequestPredicate queryParam(String name, String value) { - return queryParam(name, new Predicate() { - @Override - public boolean test(String s) { - return s.equals(value); - } - @Override - public String toString() { - return String.format("== %s", value); - } - }); + return new QueryParamPredicate(name, value); } /** @@ -379,6 +322,100 @@ public abstract class RequestPredicates { } + + /** + * Receives notifications from the logical structure of request predicates. + */ + public interface Visitor { + + /** + * Receive notification of an HTTP method predicate. + * @param methods the HTTP methods that make up the predicate + * @see RequestPredicates#method(HttpMethod) + */ + void method(Set methods); + + /** + * Receive notification of an path predicate. + * @param pattern the path pattern that makes up the predicate + * @see RequestPredicates#path(String) + */ + void path(String pattern); + + /** + * Receive notification of an path extension predicate. + * @param extension the path extension that makes up the predicate + * @see RequestPredicates#pathExtension(String) + */ + void pathExtension(String extension); + + /** + * Receive notification of a HTTP header predicate. + * @param name the name of the HTTP header to check + * @param value the desired value of the HTTP header + * @see RequestPredicates#headers(Predicate) + * @see RequestPredicates#contentType(MediaType...) + * @see RequestPredicates#accept(MediaType...) + */ + void header(String name, String value); + + /** + * Receive notification of a query parameter predicate. + * @param name the name of the query parameter + * @param value the desired value of the parameter + * @see RequestPredicates#queryParam(String, String) + */ + void queryParam(String name, String value); + + /** + * Receive first notification of a logical AND predicate. + * The first subsequent notification will contain the left-hand side of the AND-predicate; + * the second notification contains the right-hand side, followed by {@link #endAnd()}. + * @see RequestPredicate#and(RequestPredicate) + */ + void startAnd(); + + /** + * Receive last notification of a logical AND predicate. + * @see RequestPredicate#and(RequestPredicate) + */ + void endAnd(); + + /** + * Receive first notification of a logical OR predicate. + * The first subsequent notification will contain the left-hand side of the OR-predicate; + * the second notification contains the right-hand side, followed by {@link #endOr()}. + * @see RequestPredicate#or(RequestPredicate) + */ + void startOr(); + + /** + * Receive last notification of a logical OR predicate. + * @see RequestPredicate#or(RequestPredicate) + */ + void endOr(); + + /** + * Receive first notification of a negated predicate. + * The first subsequent notification will contain the negated predicated, followed + * by {@link #endNegate()}. + * @see RequestPredicate#negate() + */ + void startNegate(); + + /** + * Receive last notification of a negated predicate. + * @see RequestPredicate#negate() + */ + void endNegate(); + + /** + * Receive first notification of an unknown predicate. + */ + void unknown(RequestPredicate predicate); + } + + private static class HttpMethodPredicate implements RequestPredicate { private final Set httpMethods; @@ -401,6 +438,11 @@ public abstract class RequestPredicates { return match; } + @Override + public void accept(Visitor visitor) { + visitor.method(Collections.unmodifiableSet(this.httpMethods)); + } + @Override public String toString() { if (this.httpMethods.size() == 1) { @@ -454,6 +496,11 @@ public abstract class RequestPredicates { .map(info -> new SubPathServerRequestWrapper(request, info, this.pattern)); } + @Override + public void accept(Visitor visitor) { + visitor.path(this.pattern.getPatternString()); + } + @Override public String toString() { return this.pattern.getPatternString(); @@ -481,14 +528,115 @@ public abstract class RequestPredicates { } } + private static class ContentTypePredicate extends HeadersPredicate { + + private final Set mediaTypes; + + public ContentTypePredicate(MediaType... mediaTypes) { + this(new HashSet<>(Arrays.asList(mediaTypes))); + } + + private ContentTypePredicate(Set mediaTypes) { + super(headers -> { + MediaType contentType = + headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); + boolean match = mediaTypes.stream() + .anyMatch(mediaType -> mediaType.includes(contentType)); + traceMatch("Content-Type", mediaTypes, contentType, match); + return match; + }); + this.mediaTypes = mediaTypes; + } + + @Override + public void accept(Visitor visitor) { + visitor.header(HttpHeaders.CONTENT_TYPE, + (this.mediaTypes.size() == 1) ? + this.mediaTypes.iterator().next().toString() : + this.mediaTypes.toString()); + } + + @Override + public String toString() { + return String.format("Content-Type: %s", + (this.mediaTypes.size() == 1) ? + this.mediaTypes.iterator().next().toString() : + this.mediaTypes.toString()); + } + } + + private static class AcceptPredicate extends HeadersPredicate { + + private final Set mediaTypes; + + public AcceptPredicate(MediaType... mediaTypes) { + this(new HashSet<>(Arrays.asList(mediaTypes))); + } + + private AcceptPredicate(Set mediaTypes) { + super(headers -> { + List acceptedMediaTypes = acceptedMediaTypes(headers); + boolean match = acceptedMediaTypes.stream() + .anyMatch(acceptedMediaType -> mediaTypes.stream() + .anyMatch(acceptedMediaType::isCompatibleWith)); + traceMatch("Accept", mediaTypes, acceptedMediaTypes, match); + return match; + }); + this.mediaTypes = mediaTypes; + } + + @NonNull + private static List acceptedMediaTypes(ServerRequest.Headers headers) { + List acceptedMediaTypes = headers.accept(); + if (acceptedMediaTypes.isEmpty()) { + acceptedMediaTypes = Collections.singletonList(MediaType.ALL); + } + else { + MediaType.sortBySpecificityAndQuality(acceptedMediaTypes); + } + return acceptedMediaTypes; + } + + @Override + public void accept(Visitor visitor) { + visitor.header(HttpHeaders.ACCEPT, + (this.mediaTypes.size() == 1) ? + this.mediaTypes.iterator().next().toString() : + this.mediaTypes.toString()); + } + + @Override + public String toString() { + return String.format("Accept: %s", + (this.mediaTypes.size() == 1) ? + this.mediaTypes.iterator().next().toString() : + this.mediaTypes.toString()); + } + } + private static class PathExtensionPredicate implements RequestPredicate { private final Predicate extensionPredicate; + @Nullable + private final String extension; + public PathExtensionPredicate(Predicate extensionPredicate) { Assert.notNull(extensionPredicate, "Predicate must not be null"); this.extensionPredicate = extensionPredicate; + this.extension = null; + } + + public PathExtensionPredicate(String extension) { + Assert.notNull(extension, "Extension must not be null"); + + this.extensionPredicate = s -> { + boolean match = extension.equalsIgnoreCase(s); + traceMatch("Extension", extension, s, match); + return match; + }; + this.extension = extension; } @Override @@ -497,9 +645,20 @@ public abstract class RequestPredicates { return this.extensionPredicate.test(pathExtension); } + @Override + public void accept(Visitor visitor) { + visitor.pathExtension( + (this.extension != null) ? + this.extension : + this.extensionPredicate.toString()); + } + @Override public String toString() { - return this.extensionPredicate.toString(); + return String.format("*.%s", + (this.extension != null) ? + this.extension : + this.extensionPredicate); } } @@ -509,24 +668,47 @@ public abstract class RequestPredicates { private final String name; - private final Predicate predicate; + private final Predicate valuePredicate; + + @Nullable + private final String value; + + public QueryParamPredicate(String name, Predicate valuePredicate) { + Assert.notNull(name, "Name must not be null"); + Assert.notNull(valuePredicate, "Predicate must not be null"); + this.name = name; + this.valuePredicate = valuePredicate; + this.value = null; + } - public QueryParamPredicate(String name, Predicate predicate) { + public QueryParamPredicate(String name, String value) { Assert.notNull(name, "Name must not be null"); - Assert.notNull(predicate, "Predicate must not be null"); + Assert.notNull(value, "Value must not be null"); this.name = name; - this.predicate = predicate; + this.valuePredicate = value::equals; + this.value = value; } @Override public boolean test(ServerRequest request) { Optional s = request.queryParam(this.name); - return s.filter(this.predicate).isPresent(); + return s.filter(this.valuePredicate).isPresent(); + } + + @Override + public void accept(Visitor visitor) { + visitor.queryParam(this.name, + (this.value != null) ? + this.value : + this.valuePredicate.toString()); } @Override public String toString() { - return String.format("?%s %s", this.name, this.predicate); + return String.format("?%s %s", this.name, + (this.value != null) ? + this.value : + this.valuePredicate); } } @@ -564,6 +746,14 @@ public abstract class RequestPredicates { return this.left.nest(request).flatMap(this.right::nest); } + @Override + public void accept(Visitor visitor) { + visitor.startAnd(); + this.left.accept(visitor); + this.right.accept(visitor); + visitor.endAnd(); + } + @Override public String toString() { return String.format("(%s && %s)", this.left, this.right); @@ -591,6 +781,13 @@ public abstract class RequestPredicates { return result; } + @Override + public void accept(Visitor visitor) { + visitor.startNegate(); + this.delegate.accept(visitor); + visitor.endNegate(); + } + @Override public String toString() { return "!" + this.delegate.toString(); @@ -642,6 +839,15 @@ public abstract class RequestPredicates { } } + @Override + public void accept(Visitor visitor) { + visitor.startOr(); + this.left.accept(visitor); + this.right.accept(visitor); + visitor.endOr(); + } + + @Override public String toString() { return String.format("(%s || %s)", this.left, this.right); diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunction.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunction.java index 1326cec6ec..d3cfc95358 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunction.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunction.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -108,7 +108,7 @@ public interface RouterFunction { * Accept the given visitor. Default implementation calls * {@link RouterFunctions.Visitor#unknown(RouterFunction)}; composed {@code RouterFunction} * implementations are expected to call {@code accept} for all components that make up this - * router function + * router function. * @param visitor the visitor to accept */ default void accept(RouterFunctions.Visitor visitor) { diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ToStringVisitor.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ToStringVisitor.java index dc5d721354..b5a072f1b5 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ToStringVisitor.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/ToStringVisitor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,14 @@ package org.springframework.web.reactive.function.server; +import java.util.Set; import java.util.function.Function; import reactor.core.publisher.Mono; import org.springframework.core.io.Resource; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; /** * Implementation of {@link RouterFunctions.Visitor} that creates a formatted string representation @@ -29,7 +32,7 @@ import org.springframework.core.io.Resource; * @author Arjen Poutsma * @since 5.0 */ -class ToStringVisitor implements RouterFunctions.Visitor { +class ToStringVisitor implements RouterFunctions.Visitor, RequestPredicates.Visitor { private static final String NEW_LINE = System.getProperty("line.separator", "\\n"); @@ -37,10 +40,15 @@ class ToStringVisitor implements RouterFunctions.Visitor { private int indent = 0; + @Nullable + private String infix; + + // RouterFunctions.Visitor + @Override public void startNested(RequestPredicate predicate) { indent(); - this.builder.append(predicate); + predicate.accept(this); this.builder.append(" => {"); this.builder.append(NEW_LINE); this.indent++; @@ -57,7 +65,7 @@ class ToStringVisitor implements RouterFunctions.Visitor { @Override public void route(RequestPredicate predicate, HandlerFunction handlerFunction) { indent(); - this.builder.append(predicate); + predicate.accept(this); this.builder.append(" -> "); this.builder.append(handlerFunction); this.builder.append(NEW_LINE); @@ -82,6 +90,91 @@ class ToStringVisitor implements RouterFunctions.Visitor { } } + // RequestPredicates.Visitor + + @Override + public void method(Set methods) { + if (methods.size() == 1) { + this.builder.append(methods.iterator().next()); + } + else { + this.builder.append(methods); + } + infix(); + } + + @Override + public void path(String pattern) { + this.builder.append(pattern); + infix(); + } + + @Override + public void pathExtension(String extension) { + this.builder.append(String.format("*.%s", extension)); + infix(); + } + + @Override + public void header(String name, String value) { + this.builder.append(String.format("%s: %s", name, value)); + infix(); + } + + @Override + public void queryParam(String name, String value) { + this.builder.append(String.format("?%s == %s", name, value)); + infix(); + } + + @Override + public void startAnd() { + this.builder.append('('); + this.infix = "&&"; + } + + @Override + public void endAnd() { + this.builder.append(')'); + } + + @Override + public void startOr() { + this.builder.append('('); + this.infix = "||"; + } + + @Override + public void endOr() { + this.builder.append(')'); + } + + @Override + public void startNegate() { + this.builder.append("!("); + + } + + @Override + public void endNegate() { + this.builder.append(')'); + } + + @Override + public void unknown(RequestPredicate predicate) { + this.builder.append(predicate); + } + + private void infix() { + if (this.infix != null) { + this.builder.append(' '); + this.builder.append(this.infix); + this.builder.append(' '); + this.infix = null; + } + } + + @Override public String toString() { String result = this.builder.toString(); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/ToStringVisitorTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/ToStringVisitorTests.java new file mode 100644 index 0000000000..3469fab36d --- /dev/null +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/ToStringVisitorTests.java @@ -0,0 +1,107 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; + +import static org.junit.Assert.*; +import static org.springframework.web.reactive.function.server.RequestPredicates.accept; +import static org.springframework.web.reactive.function.server.RequestPredicates.contentType; +import static org.springframework.web.reactive.function.server.RequestPredicates.method; +import static org.springframework.web.reactive.function.server.RequestPredicates.methods; +import static org.springframework.web.reactive.function.server.RequestPredicates.path; +import static org.springframework.web.reactive.function.server.RequestPredicates.pathExtension; +import static org.springframework.web.reactive.function.server.RequestPredicates.queryParam; +import static org.springframework.web.reactive.function.server.RouterFunctions.route; + +/** + * @author Arjen Poutsma + */ +public class ToStringVisitorTests { + + @Test + public void nested() { + HandlerFunction handler = new SimpleHandlerFunction(); + RouterFunction routerFunction = route() + .path("/foo", builder -> { + builder.path("/bar", () -> route() + .GET("/baz", handler) + .build()); + }) + .build(); + + ToStringVisitor visitor = new ToStringVisitor(); + routerFunction.accept(visitor); + String result = visitor.toString(); + + String expected = "/foo => {\n" + + " /bar => {\n" + + " (GET && /baz) -> \n" + + " }\n" + + "}"; + assertEquals(expected, result); + } + + @Test + public void predicates() { + testPredicate(methods(HttpMethod.GET), "GET"); + testPredicate(methods(HttpMethod.GET, HttpMethod.POST), "[GET, POST]"); + + testPredicate(path("/foo"), "/foo"); + + testPredicate(pathExtension("foo"), "*.foo"); + + testPredicate(contentType(MediaType.APPLICATION_JSON), "Content-Type: application/json"); + testPredicate(contentType(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN), "Content-Type: [application/json, text/plain]"); + + testPredicate(accept(MediaType.APPLICATION_JSON), "Accept: application/json"); + + testPredicate(queryParam("foo", "bar"), "?foo == bar"); + + testPredicate(method(HttpMethod.GET).and(path("/foo")), "(GET && /foo)"); + + testPredicate(method(HttpMethod.GET).or(path("/foo")), "(GET || /foo)"); + + testPredicate(method(HttpMethod.GET).negate(), "!(GET)"); + } + + private void testPredicate(RequestPredicate predicate, String expected) { + ToStringVisitor visitor = new ToStringVisitor(); + predicate.accept(visitor); + String result = visitor.toString(); + + assertEquals(expected, result); + } + + private static class SimpleHandlerFunction implements HandlerFunction { + + @Override + public Mono handle(ServerRequest request) { + return ServerResponse.ok().build(); + } + + @Override + public String toString() { + return ""; + } + } + +} \ No newline at end of file