@ -5,6 +5,8 @@ import java.util.Collections;
@@ -5,6 +5,8 @@ import java.util.Collections;
import java.util.LinkedHashMap ;
import java.util.List ;
import java.util.Map ;
import java.util.concurrent.ConcurrentHashMap ;
import java.util.function.Function ;
import java.util.function.Predicate ;
import org.springframework.cloud.gateway.api.RouteLocator ;
@ -25,23 +27,25 @@ import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.G
@@ -25,23 +27,25 @@ import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.G
* /
public class RoutePredicateHandlerMapping extends AbstractHandlerMapping {
private Map < String , RoutePredicate > predicates = new LinkedHashMap < > ( ) ;
private RouteLocator routeLocator ;
private WebHandler webHandler ;
private final RouteLocator routeLocator ;
private final WebHandler webHandler ;
private final Map < String , RoutePredicate > routePredicates = new LinkedHashMap < > ( ) ;
//TODO: define semeantics for refresh (ie clearing and recalculating combinedPredicates)
private final Map < String , Predicate < ServerWebExchange > > combinedPredicates = new ConcurrentHashMap < > ( ) ;
public RoutePredicateHandlerMapping ( WebHandler webHandler , Map < String , RoutePredicate > predicates ,
public RoutePredicateHandlerMapping ( WebHandler webHandler , Map < String , RoutePredicate > routeP redicates,
RouteLocator routeLocator ) {
this . webHandler = webHandler ;
this . routeLocator = routeLocator ;
p redicates. forEach ( ( name , factory ) - > {
routeP redicates. forEach ( ( name , factory ) - > {
String key = normalizeName ( name ) ;
if ( this . p redicates. containsKey ( key ) ) {
if ( this . routeP redicates. containsKey ( key ) ) {
this . logger . warn ( "A RoutePredicate named " + key
+ " already exists, class: " + this . p redicates. get ( key )
+ " already exists, class: " + this . routeP redicates. get ( key )
+ ". It will be overwritten." ) ;
}
this . p redicates. put ( key , factory ) ;
this . routeP redicates. put ( key , factory ) ;
if ( logger . isInfoEnabled ( ) ) {
logger . info ( "Loaded RoutePredicate [" + key + "]" ) ;
}
@ -58,27 +62,21 @@ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping {
@@ -58,27 +62,21 @@ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping {
protected Mono < ? > getHandlerInternal ( ServerWebExchange exchange ) {
exchange . getAttributes ( ) . put ( GATEWAY_HANDLER_MAPPER_ATTR , getClass ( ) . getSimpleName ( ) ) ;
Route route ;
try {
route = lookupRoute ( exchange ) ;
}
catch ( Exception ex ) {
return Mono . error ( ex ) ;
}
if ( route ! = null ) {
if ( this . logger . isDebugEnabled ( ) ) {
this . logger . debug ( "Mapping [" + getExchangeDesc ( exchange ) + "] to " + route ) ;
return lookupRoute ( exchange )
. log ( "TRACE" )
. then ( ( Function < Route , Mono < ? > > ) r - > {
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "Mapping [" + getExchangeDesc ( exchange ) + "] to " + r ) ;
}
exchange . getAttributes ( ) . put ( GATEWAY_ROUTE_ATTR , route ) ;
return Mono . just ( this . webHandler ) ;
}
else if ( this . logger . isTraceEnabled ( ) ) {
this . logger . trace ( "No Route found for [" + getExchangeDesc ( exchange ) + "]" ) ;
exchange . getAttributes ( ) . put ( GATEWAY_ROUTE_ATTR , r ) ;
return Mono . just ( webHandler ) ;
} ) . otherwiseIfEmpty ( Mono . empty ( ) . then ( ( ) - > {
if ( logger . isTraceEnabled ( ) ) {
logger . trace ( "No Route found for [" + getExchangeDesc ( exchange ) + "]" ) ;
}
return Mono . empty ( ) ;
} ) ) ;
}
//TODO: get desc from factory?
@ -92,28 +90,43 @@ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping {
@@ -92,28 +90,43 @@ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping {
}
protected Route lookupRoute ( ServerWebExchange exchange ) throws Exception {
List < Route > routes = this . routeLocator . getRoutes ( ) . collectList ( ) . block ( ) ; //TODO: convert rest of class to Reactive
for ( Route route : routes ) {
if ( ! route . getPredicates ( ) . isEmpty ( ) ) {
protected Mono < Route > lookupRoute ( ServerWebExchange exchange ) {
return this . routeLocator . getRoutes ( )
//TODO: cache predicate
Predicate < ServerWebExchange > predicate = combinePredicates ( route ) ;
if ( predicate . test ( exchange ) ) {
. map ( route - > getRouteCombinedPredicates ( route ) )
. filter ( rcp - > rcp . combinedPredicate . test ( exchange ) )
. next ( )
//TODO: error handling
. map ( rcp - > {
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "Route matched: " + route . getId ( ) ) ;
logger . debug ( "Route matched: " + rcp . r oute . getId ( ) ) ;
}
validateRoute ( route , exchange ) ;
return route ;
} else {
validateRoute ( rcp . route , exchange ) ;
return rcp . route ;
} ) ;
/ * TODO : trace logging
if ( logger . isTraceEnabled ( ) ) {
logger . trace ( "Route did not match: " + route . getId ( ) ) ;
} * /
}
private RouteCombinedPredicates getRouteCombinedPredicates ( Route route ) {
Predicate < ServerWebExchange > predicate = this . combinedPredicates
. computeIfAbsent ( route . getId ( ) , k - > combinePredicates ( route ) ) ;
return new RouteCombinedPredicates ( route , predicate ) ;
}
private class RouteCombinedPredicates {
private Route route ;
private Predicate < ServerWebExchange > combinedPredicate ;
public RouteCombinedPredicates ( Route route , Predicate < ServerWebExchange > combinedPredicate ) {
this . route = route ;
this . combinedPredicate = combinedPredicate ;
}
}
return null ;
}
private Predicate < ServerWebExchange > combinePredicates ( Route route ) {
@ -129,7 +142,7 @@ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping {
@@ -129,7 +142,7 @@ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping {
}
private Predicate < ServerWebExchange > lookup ( Route route , PredicateDefinition predicate ) {
RoutePredicate found = this . p redicates. get ( predicate . getName ( ) ) ;
RoutePredicate found = this . routeP redicates. get ( predicate . getName ( ) ) ;
if ( found = = null ) {
throw new IllegalArgumentException ( "Unable to find RoutePredicate with name " + predicate . getName ( ) ) ;
}
@ -155,7 +168,7 @@ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping {
@@ -155,7 +168,7 @@ public class RoutePredicateHandlerMapping extends AbstractHandlerMapping {
* @throws Exception if validation failed
* /
@SuppressWarnings ( "UnusedParameters" )
protected void validateRoute ( Route route , ServerWebExchange exchange ) throws Exception {
protected void validateRoute ( Route route , ServerWebExchange exchange ) {
}
}