Browse Source

Adds KeyResolver interface

pull/41/head
Spencer Gibb 8 years ago
parent
commit
7d1a1e09de
No known key found for this signature in database
GPG Key ID: 7788A47380690861
  1. 3
      spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/config/GatewayAutoConfiguration.java
  2. 53
      spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/factory/RequestRateLimiterWebFilterFactory.java
  3. 12
      spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/ratelimit/KeyResolver.java
  4. 89
      spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/factory/RequestRateLimiterWebFilterFactoryTests.java

3
spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/config/GatewayAutoConfiguration.java

@ -50,6 +50,7 @@ import org.springframework.cloud.gateway.filter.factory.SetPathWebFilterFactory; @@ -50,6 +50,7 @@ import org.springframework.cloud.gateway.filter.factory.SetPathWebFilterFactory;
import org.springframework.cloud.gateway.filter.factory.SetResponseHeaderWebFilterFactory;
import org.springframework.cloud.gateway.filter.factory.SetStatusWebFilterFactory;
import org.springframework.cloud.gateway.filter.factory.WebFilterFactory;
import org.springframework.cloud.gateway.filter.ratelimit.KeyResolver;
import org.springframework.cloud.gateway.filter.ratelimit.RateLimiter;
import org.springframework.cloud.gateway.filter.ratelimit.RedisRateLimiter;
import org.springframework.cloud.gateway.handler.FilteringWebHandler;
@ -294,7 +295,7 @@ public class GatewayAutoConfiguration { @@ -294,7 +295,7 @@ public class GatewayAutoConfiguration {
}
@Bean
@ConditionalOnBean(RateLimiter.class)
@ConditionalOnBean({RateLimiter.class, KeyResolver.class})
public RequestRateLimiterWebFilterFactory requestRateLimiterWebFilterFactory(RateLimiter rateLimiter) {
return new RequestRateLimiterWebFilterFactory(rateLimiter);
}

53
spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/factory/RequestRateLimiterWebFilterFactory.java

@ -17,46 +17,71 @@ @@ -17,46 +17,71 @@
package org.springframework.cloud.gateway.filter.factory;
import org.springframework.beans.BeansException;
import org.springframework.cloud.gateway.filter.ratelimit.KeyResolver;
import org.springframework.cloud.gateway.filter.ratelimit.RateLimiter;
import org.springframework.cloud.gateway.filter.ratelimit.RateLimiter.Response;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.http.HttpStatus;
import org.springframework.tuple.Tuple;
import org.springframework.web.server.WebFilter;
import java.util.Arrays;
import java.util.List;
/**
* User Request Rate Limiter filter.
* See https://stripe.com/blog/rate-limiters and
*/
public class RequestRateLimiterWebFilterFactory implements WebFilterFactory {
public class RequestRateLimiterWebFilterFactory implements WebFilterFactory, ApplicationContextAware {
public static final String REPLENISH_RATE_KEY = "replenishRate";
public static final String BURST_CAPACITY_KEY = "burstCapacity";
public static final String KEY_RESOLVER_NAME_KEY = "keyResolverName";
private final RateLimiter rateLimiter;
private ApplicationContext context;
public RequestRateLimiterWebFilterFactory(RateLimiter rateLimiter) {
this.rateLimiter = rateLimiter;
}
@Override
public void setApplicationContext(ApplicationContext context) throws BeansException {
this.context = context;
}
@Override
public List<String> argNames() {
return Arrays.asList(REPLENISH_RATE_KEY, BURST_CAPACITY_KEY, KEY_RESOLVER_NAME_KEY);
}
@SuppressWarnings("unchecked")
@Override
public WebFilter apply(Tuple args) {
// How many requests per second do you want a user to be allowed to do?
int replenishRate = 100;
int replenishRate = args.getInt(REPLENISH_RATE_KEY);
// How much bursting do you want to allow?
int capacity = 5 * replenishRate;
int capacity = args.getInt(BURST_CAPACITY_KEY);
String beanName = args.getString(KEY_RESOLVER_NAME_KEY);
KeyResolver keyResolver = this.context.getBean(beanName, KeyResolver.class);
return (exchange, chain) -> {
// exchange.getPrincipal().flatMap(principal -> {})
//TODO: get user from request, maybe a KeyResolutionStrategy.resolve(exchange). Lookup strategy bean via arg
Response response = rateLimiter.isAllowed("me", replenishRate, capacity);
return (exchange, chain) ->
keyResolver.resolve(exchange).flatMap(key -> {
Response response = rateLimiter.isAllowed(key, replenishRate, capacity);
//TODO: set some headers for rate, tokens left
//TODO: set some headers for rate, tokens left
if (response.isAllowed()) {
return chain.filter(exchange);
}
exchange.getResponse().setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
return exchange.getResponse().setComplete();
};
if (response.isAllowed()) {
return chain.filter(exchange);
}
exchange.getResponse().setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
return exchange.getResponse().setComplete();
});
}
}

12
spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/ratelimit/KeyResolver.java

@ -0,0 +1,12 @@ @@ -0,0 +1,12 @@
package org.springframework.cloud.gateway.filter.ratelimit;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
/**
* @author Spencer Gibb
*/
//TODO: KeyResolver for exchange.getPrincipal().flatMap(principal -> {})
public interface KeyResolver {
Mono<String> resolve(ServerWebExchange exchange);
}

89
spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/filter/factory/RequestRateLimiterWebFilterFactoryTests.java

@ -2,15 +2,34 @@ package org.springframework.cloud.gateway.filter.factory; @@ -2,15 +2,34 @@ package org.springframework.cloud.gateway.filter.factory;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.cloud.gateway.filter.ratelimit.KeyResolver;
import org.springframework.cloud.gateway.filter.ratelimit.RateLimiter;
import org.springframework.cloud.gateway.filter.ratelimit.RateLimiter.Response;
import org.springframework.cloud.gateway.test.BaseWebClientTests;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.http.HttpStatus;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.MockServerWebExchange;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.tuple.Tuple;
import org.springframework.web.server.WebFilterChain;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.when;
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT;
import static org.springframework.cloud.gateway.filter.factory.RequestRateLimiterWebFilterFactory.BURST_CAPACITY_KEY;
import static org.springframework.cloud.gateway.filter.factory.RequestRateLimiterWebFilterFactory.KEY_RESOLVER_NAME_KEY;
import static org.springframework.cloud.gateway.filter.factory.RequestRateLimiterWebFilterFactory.REPLENISH_RATE_KEY;
import static org.springframework.tuple.TupleBuilder.tuple;
import reactor.core.publisher.Mono;
/**
* see https://gist.github.com/ptarjan/e38f45f2dfe601419ca3af937fff574d#file-1-check_request_rate_limiter-rb-L36-L62
@ -21,57 +40,61 @@ import static org.springframework.boot.test.context.SpringBootTest.WebEnvironmen @@ -21,57 +40,61 @@ import static org.springframework.boot.test.context.SpringBootTest.WebEnvironmen
@DirtiesContext
public class RequestRateLimiterWebFilterFactoryTests extends BaseWebClientTests {
/*@Autowired
private StringRedisTemplate redisTemplate;
@Autowired
private RedisScript<List> script;*/
private RequestRateLimiterWebFilterFactory filterFactory;
@MockBean
private RateLimiter rateLimiter;
@MockBean
private WebFilterChain filterChain;
@Test
public void requestRateLimiterWebFilterFactoryWorks() throws Exception {
/*String id = UUID.randomUUID().toString();
public void allowedWorks() throws Exception {
assertFilterFactory("resolver1", "allowedkey", true, HttpStatus.OK);
}
RequestRateLimiterWebFilterFactory filterFactory = new RequestRateLimiterWebFilterFactory(this.redisTemplate, this.script);
@Test
public void notAllowedWorks() throws Exception {
assertFilterFactory("resolver2", "notallowedkey", false, HttpStatus.TOO_MANY_REQUESTS);
}
private void assertFilterFactory(String keyResolverName, String key, boolean allowed, HttpStatus expectedStatus) {
int replenishRate = 10;
int capacity = 2 * replenishRate;
int burstCapacity = 2 * replenishRate;
Tuple args = tuple().of(REPLENISH_RATE_KEY, replenishRate,
BURST_CAPACITY_KEY, burstCapacity,
KEY_RESOLVER_NAME_KEY, keyResolverName);
// Bursts work
for (int i = 0; i < capacity; i++) {
boolean allowed = filterFactory.isAllowed(replenishRate, capacity, id);
assertThat(allowed).isTrue();
}
boolean allowed = filterFactory.isAllowed(replenishRate, capacity, id);
assertThat(allowed).isFalse();
when(rateLimiter.isAllowed(key, replenishRate, burstCapacity))
.thenReturn(new Response(allowed, 1));
Thread.sleep(1000);
// # After the burst is done, check the steady state
for (int i = 0; i < replenishRate; i++) {
allowed = filterFactory.isAllowed(replenishRate, capacity, id);
assertThat(allowed).isTrue();
}
MockServerHttpRequest request = MockServerHttpRequest.get("/").build();
MockServerWebExchange exchange = new MockServerWebExchange(request);
exchange.getResponse().setStatusCode(HttpStatus.OK);
when(this.filterChain.filter(exchange)).thenReturn(Mono.empty());
Mono<Void> response = filterFactory.apply(args).filter(exchange, this.filterChain);
response.subscribe(aVoid -> {
assertThat(exchange.getResponse().getStatusCode()).isEqualTo(expectedStatus);
});
allowed = filterFactory.isAllowed(replenishRate, capacity, id);
assertThat(allowed).isFalse();*/
}
@EnableAutoConfiguration
@SpringBootConfiguration
@Import(BaseWebClientTests.DefaultTestConfig.class)
public static class TestConfig {
/*@Bean
public RedisScript<List> requestRateLimiterScript() {
DefaultRedisScript<List> redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("META-INF/scripts/request_rate_limiter.lua")));
redisScript.setResultType(List.class);
return redisScript;
@Bean
KeyResolver resolver1() {
return exchange -> Mono.just("allowedkey");
}
@Bean
public RequestRateLimiterWebFilterFactory requestRateLimiterWebFilterFactory(StringRedisTemplate redisTemplate) {
return new RequestRateLimiterWebFilterFactory(redisTemplate, requestRateLimiterScript());
}*/
KeyResolver resolver2() {
return exchange -> Mono.just("notallowedkey");
}
}
}

Loading…
Cancel
Save