@ -1,14 +1,20 @@
@@ -1,14 +1,20 @@
package org.springframework.cloud.gateway.filter.ratelimit ;
import org.apache.commons.logging.Log ;
import org.apache.commons.logging.LogFactory ;
import org.springframework.data.redis.core.StringRedisTemplate ;
import org.springframework.data.redis.core.script.RedisScript ;
import java.time.Duration ;
import java.time.Instant ;
import java.util.ArrayList ;
import java.util.Arrays ;
import java.util.HashMap ;
import java.util.List ;
import org.apache.commons.logging.Log ;
import org.apache.commons.logging.LogFactory ;
import org.springframework.data.redis.core.ReactiveRedisTemplate ;
import reactor.core.publisher.Flux ;
import reactor.core.publisher.Mono ;
import reactor.util.function.Tuple2 ;
/ * *
* See https : //stripe.com/blog/rate-limiters and
* https : //gist.github.com/ptarjan/e38f45f2dfe601419ca3af937fff574d#file-1-check_request_rate_limiter-rb-L11-L34
@ -18,12 +24,10 @@ import java.util.List;
@@ -18,12 +24,10 @@ import java.util.List;
public class RedisRateLimiter implements RateLimiter {
private Log log = LogFactory . getLog ( getClass ( ) ) ;
private final StringRedisTemplate redisTemplate ;
private final RedisScript < List > script ;
private final ReactiveRedisTemplate < Object , Object > redisTemplate ;
public RedisRateLimiter ( StringRedisTemplate redisTemplate , RedisScript < Lis t> sc ript) {
public RedisRateLimiter ( ReactiveRedisTemplate < Object , Objec t> red isTem pla te ) {
this . redisTemplate = redisTemplate ;
this . script = script ;
}
/ * *
@ -37,29 +41,77 @@ public class RedisRateLimiter implements RateLimiter {
@@ -37,29 +41,77 @@ public class RedisRateLimiter implements RateLimiter {
@Override
//TODO: signature? params (tuple?).
//TODO: change to Mono<?>
public Response isAllowed ( String id , int replenishRate , int burstCapacity ) {
public Response isAllowed ( String id , long replenishRate , long burstCapacity ) {
try {
// Make a unique key per user.
String prefix = "request_rate_limiter." + id ;
String key = "request_rate_limiter." + id ;
// You need two Redis keys for Token Bucket.
List < String > keys = Arrays . asList ( prefix + ".tokens" , prefix + ".timestamp" ) ;
// String tokensKey = key + ".tokens";
// String timestampKey = key + ".timestamp";
// The arguments to the LUA script. time() returns unixtime in seconds.
Object [ ] args = new String [ ] { replenishRate + "" , burstCapacity + "" , Instant . now ( ) . getEpochSecond ( ) + "" , "1" } ;
// allowed, tokens_left = redis.eval(SCRIPT, keys, args)
List results = this . redisTemplate . execute ( this . script , keys , args ) ;
long now = Instant . now ( ) . getEpochSecond ( ) ;
int requested = 1 ;
boolean allowed = new Long ( 1L ) . equals ( results . get ( 0 ) ) ;
Long tokensLeft = ( Long ) results . get ( 1 ) ;
double fillTime = ( double ) burstCapacity / ( double ) replenishRate ;
int ttl = ( int ) Math . floor ( fillTime * 2 ) ;
Response response = new Response ( allowed , tokensLeft ) ;
Mono < Boolean > booleanMono = this . redisTemplate . hasKey ( key ) ;
Boolean hasKey = booleanMono . block ( ) ;
if ( log . isDebugEnabled ( ) ) {
log . debug ( "response: " + response ) ;
Mono < List < Object > > valuesMono ;
if ( hasKey ) {
valuesMono = this . redisTemplate . opsForHash ( ) . multiGet ( key , Arrays . asList ( "tokens" , "timestamp" ) ) ;
} else {
valuesMono = Mono . just ( new ArrayList < > ( ) ) ;
}
return response ;
Mono < Response > responseMono = valuesMono . map ( objects - > {
Long lastTokens = null ;
if ( objects . size ( ) > = 1 ) {
lastTokens = ( Long ) objects . get ( 0 ) ;
}
if ( lastTokens = = null ) {
lastTokens = burstCapacity ;
}
Long lastRefreshed = null ;
if ( objects . size ( ) > = 2 ) {
lastRefreshed = ( Long ) objects . get ( 1 ) ;
}
if ( lastRefreshed = = null ) {
lastRefreshed = 0L ;
}
long delta = Math . max ( 0 , ( now - lastRefreshed ) ) ;
long filledTokens = Math . min ( burstCapacity , lastTokens + ( delta * replenishRate ) ) ;
boolean allowed = filledTokens > = requested ;
long newTokens = filledTokens ;
if ( allowed ) {
newTokens = filledTokens - requested ;
}
HashMap < Object , Object > updated = new HashMap < > ( ) ;
updated . put ( "tokens" , newTokens ) ;
updated . put ( "timestamp" , now ) ;
Mono < Boolean > putAllMono = this . redisTemplate . opsForHash ( ) . putAll ( key , updated ) ;
Mono < Boolean > expireMono = this . redisTemplate . expire ( key , Duration . ofSeconds ( ttl ) ) ;
Flux < Tuple2 < Boolean , Boolean > > zip = Flux . zip ( putAllMono , expireMono ) ;
Tuple2 < Boolean , Boolean > objects1 = zip . blockLast ( ) ;
Response response = new Response ( allowed , newTokens ) ;
if ( log . isDebugEnabled ( ) ) {
log . debug ( "response: " + response ) ;
}
return response ;
} ) ;
return responseMono . block ( ) ;
} catch ( Exception e ) {
/ * We don ' t want a hard dependency on Redis to allow traffic .