diff --git a/pom.xml b/pom.xml index 71abb925b..b0b5dbfd8 100644 --- a/pom.xml +++ b/pom.xml @@ -51,6 +51,7 @@ UTF-8 UTF-8 + 8.3.0 1.0.6.RELEASE 17 1.6.1 @@ -97,6 +98,16 @@ spring-boot-devtools ${spring-boot.version} + + com.bucket4j + bucket4j-core + ${bucket4j.version} + + + com.bucket4j + bucket4j-caffeine + ${bucket4j.version} + io.projectreactor.tools blockhound diff --git a/spring-cloud-gateway-server-mvc/pom.xml b/spring-cloud-gateway-server-mvc/pom.xml index 297ba2e16..e0e06c071 100644 --- a/spring-cloud-gateway-server-mvc/pom.xml +++ b/spring-cloud-gateway-server-mvc/pom.xml @@ -66,6 +66,11 @@ spring-retry true + + com.bucket4j + bucket4j-core + true + org.springframework.boot spring-boot-starter-test @@ -78,5 +83,15 @@ httpclient5 test + + com.github.ben-manes.caffeine + caffeine + test + + + com.bucket4j + bucket4j-caffeine + test + \ No newline at end of file diff --git a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/Bucket4jFilterFunctions.java b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/Bucket4jFilterFunctions.java new file mode 100644 index 000000000..59ba2f1a1 --- /dev/null +++ b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/Bucket4jFilterFunctions.java @@ -0,0 +1,132 @@ +/* + * Copyright 2013-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.cloud.gateway.server.mvc.filter; + +import java.lang.reflect.Method; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collection; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import java.util.function.Function; + +import io.github.bucket4j.Bandwidth; +import io.github.bucket4j.BucketConfiguration; +import io.github.bucket4j.ConsumptionProbe; +import io.github.bucket4j.distributed.AsyncBucketProxy; +import io.github.bucket4j.distributed.proxy.AsyncProxyManager; + +import org.springframework.cloud.gateway.server.mvc.common.MvcUtils; +import org.springframework.http.HttpStatus; +import org.springframework.http.HttpStatusCode; +import org.springframework.web.servlet.function.HandlerFilterFunction; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; + +public abstract class Bucket4jFilterFunctions { + + private Bucket4jFilterFunctions() { + } + + public static HandlerFilterFunction rateLimit(long capacity, Duration period, + Function keyResolver) { + return rateLimit(c -> c.setCapacity(capacity).setPeriod(period).setKeyResolver(keyResolver)); + } + + public static HandlerFilterFunction rateLimit( + Consumer configConsumer) { + RateLimitConfig config = new RateLimitConfig(); + configConsumer.accept(config); + BucketConfiguration bucketConfiguration = BucketConfiguration.builder() + .addLimit(Bandwidth.simple(config.getCapacity(), config.getPeriod())).build(); + return (request, next) -> { + AsyncProxyManager proxyManager = MvcUtils.getApplicationContext(request).getBean(AsyncProxyManager.class); + AsyncBucketProxy bucket = proxyManager.builder().build(config.getKeyResolver().apply(request), + bucketConfiguration); + // TODO: configurable tokens + CompletableFuture bucketFuture = bucket.tryConsumeAndReturnRemaining(1); + // TODO: configurable timeout + ConsumptionProbe consumptionProbe = bucketFuture.get(); + boolean allowed = consumptionProbe.isConsumed(); + long remainingTokens = consumptionProbe.getRemainingTokens(); + if (allowed) { + ServerResponse serverResponse = next.handle(request); + // TODO: configurable headers + serverResponse.headers().add("X-RateLimit-Remaining", String.valueOf(remainingTokens)); + return serverResponse; + } + return ServerResponse.status(config.getStatusCode()) + .header("X-RateLimit-Remaining", String.valueOf(remainingTokens)).build(); + }; + } + + public static class RateLimitConfig { + + long capacity; + + Duration period; + + Function keyResolver; + + HttpStatusCode statusCode = HttpStatus.TOO_MANY_REQUESTS; + + public long getCapacity() { + return capacity; + } + + public RateLimitConfig setCapacity(long capacity) { + this.capacity = capacity; + return this; + } + + public Duration getPeriod() { + return period; + } + + public RateLimitConfig setPeriod(Duration period) { + this.period = period; + return this; + } + + public Function getKeyResolver() { + return keyResolver; + } + + public RateLimitConfig setKeyResolver(Function keyResolver) { + this.keyResolver = keyResolver; + return this; + } + + public HttpStatusCode getStatusCode() { + return statusCode; + } + + public RateLimitConfig setStatusCode(HttpStatusCode statusCode) { + this.statusCode = statusCode; + return this; + } + + } + + static class FilterSupplier implements org.springframework.cloud.gateway.server.mvc.filter.FilterSupplier { + @Override + public Collection get() { + return Arrays.asList(Bucket4jFilterFunctions.class.getMethods()); + } + } + +} diff --git a/spring-cloud-gateway-server-mvc/src/main/resources/META-INF/spring.factories b/spring-cloud-gateway-server-mvc/src/main/resources/META-INF/spring.factories index 2c81ef435..75a9b178c 100644 --- a/spring-cloud-gateway-server-mvc/src/main/resources/META-INF/spring.factories +++ b/spring-cloud-gateway-server-mvc/src/main/resources/META-INF/spring.factories @@ -1,4 +1,5 @@ org.springframework.cloud.gateway.server.mvc.filter.FilterSupplier=\ + org.springframework.cloud.gateway.server.mvc.filter.Bucket4jFilterFunctions.FilterSupplier,\ org.springframework.cloud.gateway.server.mvc.filter.FilterFunctionsFilterSupplier,\ org.springframework.cloud.gateway.server.mvc.filter.CircuitBreakerFilterFunctionsFilterSupplier diff --git a/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/ServerMvcIntegrationTests.java b/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/ServerMvcIntegrationTests.java index adf71e0d5..dadd01bfc 100644 --- a/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/ServerMvcIntegrationTests.java +++ b/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/ServerMvcIntegrationTests.java @@ -17,11 +17,16 @@ package org.springframework.cloud.gateway.server.mvc; import java.net.URI; +import java.time.Duration; import java.util.Collections; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; +import com.github.benmanes.caffeine.cache.Caffeine; +import io.github.bucket4j.caffeine.CaffeineProxyManager; +import io.github.bucket4j.distributed.proxy.AsyncProxyManager; +import io.github.bucket4j.distributed.remote.RemoteBucketState; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.junit.jupiter.api.Test; @@ -50,6 +55,7 @@ import org.springframework.web.servlet.function.ServerRequest; import org.springframework.web.servlet.function.ServerResponse; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.cloud.gateway.server.mvc.filter.Bucket4jFilterFunctions.rateLimit; import static org.springframework.cloud.gateway.server.mvc.filter.CircuitBreakerFilterFunctions.circuitBreaker; import static org.springframework.cloud.gateway.server.mvc.filter.FilterFunctions.addRequestHeader; import static org.springframework.cloud.gateway.server.mvc.filter.FilterFunctions.addRequestParameter; @@ -208,6 +214,12 @@ public class ServerMvcIntegrationTests { restClient.get().uri("/retry?key=get").exchange().expectStatus().isOk().expectBody(String.class).isEqualTo("3"); } + @Test + public void rateLimitWorks() { + restClient.get().uri("/anything/ratelimit").exchange().expectStatus().isOk(); + restClient.get().uri("/anything/ratelimit").exchange().expectStatus().isEqualTo(HttpStatus.TOO_MANY_REQUESTS); + } + @SpringBootConfiguration @EnableAutoConfiguration @LoadBalancerClient(name = "testservice", configuration = TestLoadBalancerConfig.class) @@ -223,6 +235,12 @@ public class ServerMvcIntegrationTests { return new RetryController(); } + @Bean + public AsyncProxyManager caffeineProxyManager() { + Caffeine builder = (Caffeine) Caffeine.newBuilder().maximumSize(100); + return new CaffeineProxyManager<>(builder, Duration.ofMinutes(1)).asAsync(); + } + @Bean public RouterFunction nonGatewayRouterFunctions(TestHandler testHandler) { return route(GET("/hello"), testHandler::hello); @@ -365,6 +383,19 @@ public class ServerMvcIntegrationTests { // @formatter:on } + @Bean + public RouterFunction gatewayRouterFunctionsRateLimit() { + // @formatter:off + return route(GET("/anything/ratelimit"), http()) + .filter(new LocalServerPortUriResolver()) + //.filter(rateLimit(1, Duration.ofMinutes(1), request -> "ratelimittest1min")) + .filter(rateLimit(c -> c.setCapacity(1) + .setPeriod(Duration.ofMinutes(1)) + .setKeyResolver(request -> "ratelimitttest1min"))) + .filter(prefixPath("/httpbin")); + // @formatter:on + } + } @RestController