Spencer Gibb
8 years ago
5 changed files with 230 additions and 0 deletions
@ -0,0 +1,95 @@
@@ -0,0 +1,95 @@
|
||||
/* |
||||
* Copyright 2013-2017 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
* |
||||
*/ |
||||
|
||||
package org.springframework.cloud.gateway.filter.factory; |
||||
|
||||
import java.time.Instant; |
||||
import java.util.Arrays; |
||||
import java.util.List; |
||||
|
||||
import org.apache.commons.logging.Log; |
||||
import org.apache.commons.logging.LogFactory; |
||||
import org.springframework.data.redis.core.StringRedisTemplate; |
||||
import org.springframework.data.redis.core.script.RedisScript; |
||||
import org.springframework.http.HttpStatus; |
||||
import org.springframework.tuple.Tuple; |
||||
import org.springframework.web.server.WebFilter; |
||||
|
||||
/** |
||||
* Sample User Request Rate Throttle filter. |
||||
* See https://stripe.com/blog/rate-limiters and
|
||||
* https://gist.github.com/ptarjan/e38f45f2dfe601419ca3af937fff574d#file-1-check_request_rate_limiter-rb-L11-L34
|
||||
*/ |
||||
public class RequestRateLimiterWebFilterFactory implements WebFilterFactory { |
||||
private Log log = LogFactory.getLog(getClass()); |
||||
|
||||
private final StringRedisTemplate redisTemplate; |
||||
private final RedisScript<List> script; |
||||
|
||||
public RequestRateLimiterWebFilterFactory(StringRedisTemplate redisTemplate, RedisScript<List> script) { |
||||
this.redisTemplate = redisTemplate; |
||||
this.script = script; |
||||
} |
||||
|
||||
@SuppressWarnings("unchecked") |
||||
@Override |
||||
public WebFilter apply(Tuple args) { |
||||
// How many requests per second do you want a user to be allowed to do?
|
||||
int replenishRate = 100; |
||||
|
||||
// How much bursting do you want to allow?
|
||||
int capacity = 5 * replenishRate; |
||||
|
||||
return (exchange, chain) -> { |
||||
boolean allowed = isAllowed(replenishRate, capacity, "me"); //TODO: get user from request
|
||||
|
||||
if (allowed) { |
||||
return chain.filter(exchange); |
||||
} |
||||
exchange.getResponse().setStatusCode(HttpStatus.TOO_MANY_REQUESTS); |
||||
return exchange.getResponse().setComplete(); |
||||
}; |
||||
} |
||||
|
||||
/* for testing */ boolean isAllowed(int replenishRate, int capacity, String id) { |
||||
boolean allowed = false; |
||||
|
||||
try { |
||||
// # Make a unique key per user.
|
||||
String prefix = "request_rate_limiter." + id; |
||||
|
||||
// # You need two Redis keys for Token Bucket.
|
||||
List<String> keys = Arrays.asList(prefix + ".tokens", prefix + ".timestamp"); |
||||
|
||||
// The arguments to the LUA script. time() returns unixtime in seconds.
|
||||
String[] args = new String[]{ replenishRate+"", capacity+"", Instant.now().getEpochSecond()+"", "1"}; |
||||
// allowed, tokens_left = redis.eval(SCRIPT, keys, args)
|
||||
List results = this.redisTemplate.execute(this.script, keys, args); |
||||
|
||||
allowed = new Long(1L).equals(results.get(0)); |
||||
Long tokensLeft = (Long) results.get(1); |
||||
|
||||
if (log.isDebugEnabled()) { |
||||
log.debug("isAllowed("+id+")=" + allowed + ", tokensLeft: "+tokensLeft); |
||||
} |
||||
|
||||
} catch (Exception e) { |
||||
log.error("Error determining if user allowed from redis", e); |
||||
} |
||||
return allowed; |
||||
} |
||||
} |
@ -0,0 +1,34 @@
@@ -0,0 +1,34 @@
|
||||
local tokens_key = KEYS[1] |
||||
local timestamp_key = KEYS[2] |
||||
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key) |
||||
|
||||
local rate = tonumber(ARGV[1]) |
||||
local capacity = tonumber(ARGV[2]) |
||||
local now = tonumber(ARGV[3]) |
||||
local requested = tonumber(ARGV[4]) |
||||
|
||||
local fill_time = capacity/rate |
||||
local ttl = math.floor(fill_time*2) |
||||
|
||||
local last_tokens = tonumber(redis.call("get", tokens_key)) |
||||
if last_tokens == nil then |
||||
last_tokens = capacity |
||||
end |
||||
|
||||
local last_refreshed = tonumber(redis.call("get", timestamp_key)) |
||||
if last_refreshed == nil then |
||||
last_refreshed = 0 |
||||
end |
||||
|
||||
local delta = math.max(0, now-last_refreshed) |
||||
local filled_tokens = math.min(capacity, last_tokens+(delta*rate)) |
||||
local allowed = filled_tokens >= requested |
||||
local new_tokens = filled_tokens |
||||
if allowed then |
||||
new_tokens = filled_tokens - requested |
||||
end |
||||
|
||||
redis.call("setex", tokens_key, ttl, new_tokens) |
||||
redis.call("setex", timestamp_key, ttl, now) |
||||
|
||||
return { allowed, new_tokens } |
@ -0,0 +1,89 @@
@@ -0,0 +1,89 @@
|
||||
package org.springframework.cloud.gateway.filter.factory; |
||||
|
||||
import java.util.List; |
||||
import java.util.UUID; |
||||
|
||||
import org.junit.Test; |
||||
import org.junit.runner.RunWith; |
||||
import org.springframework.beans.factory.annotation.Autowired; |
||||
import org.springframework.boot.SpringBootConfiguration; |
||||
import org.springframework.boot.autoconfigure.EnableAutoConfiguration; |
||||
import org.springframework.boot.test.context.SpringBootTest; |
||||
import org.springframework.cloud.gateway.test.BaseWebClientTests; |
||||
import org.springframework.context.annotation.Bean; |
||||
import org.springframework.context.annotation.Import; |
||||
import org.springframework.core.io.ClassPathResource; |
||||
import org.springframework.data.redis.core.StringRedisTemplate; |
||||
import org.springframework.data.redis.core.script.DefaultRedisScript; |
||||
import org.springframework.data.redis.core.script.RedisScript; |
||||
import org.springframework.scripting.support.ResourceScriptSource; |
||||
import org.springframework.test.annotation.DirtiesContext; |
||||
import org.springframework.test.context.junit4.SpringRunner; |
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat; |
||||
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT; |
||||
|
||||
/** |
||||
* see https://gist.github.com/ptarjan/e38f45f2dfe601419ca3af937fff574d#file-1-check_request_rate_limiter-rb-L36-L62
|
||||
* @author Spencer Gibb |
||||
*/ |
||||
@RunWith(SpringRunner.class) |
||||
@SpringBootTest(properties = "logging.level.org.springframework.cloud.gateway.filter.factory=DEBUG", |
||||
webEnvironment = RANDOM_PORT) |
||||
@DirtiesContext |
||||
public class RequestRateLimiterWebFilterFactoryTests extends BaseWebClientTests { |
||||
|
||||
@Autowired |
||||
private StringRedisTemplate redisTemplate; |
||||
|
||||
@Autowired |
||||
private RedisScript<List> script; |
||||
|
||||
@Test |
||||
public void requestRateLimiterWebFilterFactoryWorks() throws Exception { |
||||
String id = UUID.randomUUID().toString(); |
||||
|
||||
RequestRateLimiterWebFilterFactory filterFactory = new RequestRateLimiterWebFilterFactory(this.redisTemplate, this.script); |
||||
|
||||
int replenishRate = 10; |
||||
int capacity = 5 * replenishRate; |
||||
|
||||
// Bursts work
|
||||
for (int i = 0; i < capacity; i++) { |
||||
boolean allowed = filterFactory.isAllowed(replenishRate, capacity, id); |
||||
assertThat(allowed).isTrue(); |
||||
} |
||||
|
||||
boolean allowed = filterFactory.isAllowed(replenishRate, capacity, id); |
||||
assertThat(allowed).isFalse(); |
||||
|
||||
Thread.sleep(1000); |
||||
|
||||
// # After the burst is done, check the steady state
|
||||
for (int i = 0; i < replenishRate; i++) { |
||||
allowed = filterFactory.isAllowed(replenishRate, capacity, id); |
||||
assertThat(allowed).isTrue(); |
||||
} |
||||
|
||||
allowed = filterFactory.isAllowed(replenishRate, capacity, id); |
||||
assertThat(allowed).isFalse(); |
||||
} |
||||
|
||||
@EnableAutoConfiguration |
||||
@SpringBootConfiguration |
||||
@Import(BaseWebClientTests.DefaultTestConfig.class) |
||||
public static class TestConfig { |
||||
@Bean |
||||
public RedisScript<List> requestRateLimiterScript() { |
||||
DefaultRedisScript<List> redisScript = new DefaultRedisScript<>(); |
||||
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("META-INF/scripts/request_rate_limiter.lua"))); |
||||
redisScript.setResultType(List.class); |
||||
return redisScript; |
||||
} |
||||
|
||||
@Bean |
||||
public RequestRateLimiterWebFilterFactory requestRateLimiterWebFilterFactory(StringRedisTemplate redisTemplate) { |
||||
return new RequestRateLimiterWebFilterFactory(redisTemplate, requestRateLimiterScript()); |
||||
} |
||||
} |
||||
} |
Loading…
Reference in new issue