diff --git a/src/main/java/org/springframework/cloud/gateway/handler/RoutePredicateHandlerMapping.java b/src/main/java/org/springframework/cloud/gateway/handler/RoutePredicateHandlerMapping.java index 0bfd015e4..6e649093d 100644 --- a/src/main/java/org/springframework/cloud/gateway/handler/RoutePredicateHandlerMapping.java +++ b/src/main/java/org/springframework/cloud/gateway/handler/RoutePredicateHandlerMapping.java @@ -5,6 +5,8 @@ import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; import java.util.function.Predicate; import org.springframework.cloud.gateway.api.RouteLocator; @@ -25,23 +27,25 @@ import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.G */ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping { - private Map predicates = new LinkedHashMap<>(); - private RouteLocator routeLocator; - private WebHandler webHandler; + private final RouteLocator routeLocator; + private final WebHandler webHandler; + private final Map routePredicates = new LinkedHashMap<>(); + //TODO: define semeantics for refresh (ie clearing and recalculating combinedPredicates) + private final Map> combinedPredicates = new ConcurrentHashMap<>(); - public RoutePredicateHandlerMapping(WebHandler webHandler, Map predicates, + public RoutePredicateHandlerMapping(WebHandler webHandler, Map routePredicates, RouteLocator routeLocator) { this.webHandler = webHandler; this.routeLocator = routeLocator; - predicates.forEach((name, factory) -> { + routePredicates.forEach((name, factory) -> { String key = normalizeName(name); - if (this.predicates.containsKey(key)) { + if (this.routePredicates.containsKey(key)) { this.logger.warn("A RoutePredicate named "+ key - + " already exists, class: " + this.predicates.get(key) + + " already exists, class: " + this.routePredicates.get(key) + ". It will be overwritten."); } - this.predicates.put(key, factory); + this.routePredicates.put(key, factory); if (logger.isInfoEnabled()) { logger.info("Loaded RoutePredicate [" + key + "]"); } @@ -58,27 +62,21 @@ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping { protected Mono getHandlerInternal(ServerWebExchange exchange) { exchange.getAttributes().put(GATEWAY_HANDLER_MAPPER_ATTR, getClass().getSimpleName()); - Route route; - try { - route = lookupRoute(exchange); - } - catch (Exception ex) { - return Mono.error(ex); - } - - if (route != null) { - if (this.logger.isDebugEnabled()) { - this.logger.debug("Mapping [" + getExchangeDesc(exchange) + "] to " + route); - } - - exchange.getAttributes().put(GATEWAY_ROUTE_ATTR, route); - return Mono.just(this.webHandler); - } - else if (this.logger.isTraceEnabled()) { - this.logger.trace("No Route found for [" + getExchangeDesc(exchange) + "]"); - } + return lookupRoute(exchange) + .log("TRACE") + .then((Function>) r -> { + if (logger.isDebugEnabled()) { + logger.debug("Mapping [" + getExchangeDesc(exchange) + "] to " + r); + } - return Mono.empty(); + exchange.getAttributes().put(GATEWAY_ROUTE_ATTR, r); + return Mono.just(webHandler); + }).otherwiseIfEmpty(Mono.empty().then(() -> { + if (logger.isTraceEnabled()) { + logger.trace("No Route found for [" + getExchangeDesc(exchange) + "]"); + } + return Mono.empty(); + })); } //TODO: get desc from factory? @@ -92,27 +90,42 @@ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping { } - protected Route lookupRoute(ServerWebExchange exchange) throws Exception { - List routes = this.routeLocator.getRoutes().collectList().block(); //TODO: convert rest of class to Reactive - - for (Route route : routes) { - if (!route.getPredicates().isEmpty()) { + protected Mono lookupRoute(ServerWebExchange exchange) { + return this.routeLocator.getRoutes() //TODO: cache predicate - Predicate predicate = combinePredicates(route); - if (predicate.test(exchange)) { + .map(route -> getRouteCombinedPredicates(route)) + .filter(rcp -> rcp.combinedPredicate.test(exchange)) + .next() + //TODO: error handling + .map(rcp -> { if (logger.isDebugEnabled()) { - logger.debug("Route matched: " + route.getId()); - } - validateRoute(route, exchange); - return route; - } else { - if (logger.isTraceEnabled()) { - logger.trace("Route did not match: " + route.getId()); + logger.debug("Route matched: " + rcp.route.getId()); } - } - } + validateRoute(rcp.route, exchange); + return rcp.route; + }); + + /* TODO: trace logging + if (logger.isTraceEnabled()) { + logger.trace("Route did not match: " + route.getId()); + }*/ + } + + private RouteCombinedPredicates getRouteCombinedPredicates(Route route) { + Predicate predicate = this.combinedPredicates + .computeIfAbsent(route.getId(), k -> combinePredicates(route)); + + return new RouteCombinedPredicates(route, predicate); + } + + private class RouteCombinedPredicates { + private Route route; + private Predicate combinedPredicate; + + public RouteCombinedPredicates(Route route, Predicate combinedPredicate) { + this.route = route; + this.combinedPredicate = combinedPredicate; } - return null; } @@ -129,7 +142,7 @@ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping { } private Predicate lookup(Route route, PredicateDefinition predicate) { - RoutePredicate found = this.predicates.get(predicate.getName()); + RoutePredicate found = this.routePredicates.get(predicate.getName()); if (found == null) { throw new IllegalArgumentException("Unable to find RoutePredicate with name " + predicate.getName()); } @@ -155,7 +168,7 @@ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping { * @throws Exception if validation failed */ @SuppressWarnings("UnusedParameters") - protected void validateRoute(Route route, ServerWebExchange exchange) throws Exception { + protected void validateRoute(Route route, ServerWebExchange exchange) { } }