Browse Source

Support loadbalancer lifecycle (#2066)

* Add LoadBalancer lifecycle callbacks.

* Switch test to junit jupiter.

* Add more tests.

* Update docs.

* Adjust to changes in commons.
pull/2083/head
Olga Maciaszek-Sharma 4 years ago committed by GitHub
parent
commit
c19af741aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      docs/src/main/asciidoc/spring-cloud-gateway.adoc
  2. 42
      spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/ReactiveLoadBalancerClientFilter.java
  3. 12
      spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/support/ServerWebExchangeUtils.java
  4. 147
      spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/ReactiveLoadBalancerClientFilterTests.java

4
docs/src/main/asciidoc/spring-cloud-gateway.adoc

@ -1803,9 +1803,7 @@ However, if `GATEWAY_SCHEME_PREFIX_ATTR` is specified for the @@ -1803,9 +1803,7 @@ However, if `GATEWAY_SCHEME_PREFIX_ATTR` is specified for the
route in the Gateway configuration, the prefix is stripped and the resulting scheme from the
route URL overrides the `ServiceInstance` configuration.
WARNING: `LoadBalancerClientFilter` uses a blocking ribbon `LoadBalancerClient` under the hood.
We suggest you use <<reactive-loadbalancer-client-filter,`ReactiveLoadBalancerClientFilter` instead>>.
You can switch to it by setting the value of the `spring.cloud.loadbalancer.ribbon.enabled` to `false`.
TIP: Gateway supports all the LoadBalancer features. You can read more about them in the https://docs.spring.io/spring-cloud-commons/docs/current/reference/html/#spring-cloud-loadbalancer[Spring Cloud Commons documentation].
[[reactive-loadbalancer-client-filter]]
=== The `ReactiveLoadBalancerClientFilter`

42
spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/ReactiveLoadBalancerClientFilter.java

@ -18,16 +18,21 @@ package org.springframework.cloud.gateway.filter; @@ -18,16 +18,21 @@ package org.springframework.cloud.gateway.filter;
import java.net.URI;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Mono;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.CompletionContext;
import org.springframework.cloud.client.loadbalancer.DefaultRequest;
import org.springframework.cloud.client.loadbalancer.LoadBalancerLifecycle;
import org.springframework.cloud.client.loadbalancer.LoadBalancerLifecycleValidator;
import org.springframework.cloud.client.loadbalancer.LoadBalancerUriTools;
import org.springframework.cloud.client.loadbalancer.RequestData;
import org.springframework.cloud.client.loadbalancer.RequestDataContext;
import org.springframework.cloud.client.loadbalancer.Response;
import org.springframework.cloud.client.loadbalancer.ServerHttpRequestContext;
import org.springframework.cloud.client.loadbalancer.reactive.LoadBalancerProperties;
import org.springframework.cloud.gateway.config.GatewayLoadBalancerProperties;
import org.springframework.cloud.gateway.support.DelegatingServiceInstance;
@ -36,8 +41,12 @@ import org.springframework.cloud.loadbalancer.core.ReactorLoadBalancer; @@ -36,8 +41,12 @@ import org.springframework.cloud.loadbalancer.core.ReactorLoadBalancer;
import org.springframework.cloud.loadbalancer.core.ReactorServiceInstanceLoadBalancer;
import org.springframework.cloud.loadbalancer.support.LoadBalancerClientFactory;
import org.springframework.core.Ordered;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.server.ServerWebExchange;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_LOADBALANCER_RESPONSE_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_SCHEME_PREFIX_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.addOriginalRequestUrl;
@ -50,7 +59,7 @@ import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.a @@ -50,7 +59,7 @@ import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.a
* @author Tim Ysewyn
* @author Olga Maciaszek-Sharma
*/
@SuppressWarnings("rawtypes")
@SuppressWarnings({ "rawtypes", "unchecked" })
public class ReactiveLoadBalancerClientFilter implements GlobalFilter, Ordered {
private static final Log log = LogFactory.getLog(ReactiveLoadBalancerClientFilter.class);
@ -79,7 +88,6 @@ public class ReactiveLoadBalancerClientFilter implements GlobalFilter, Ordered { @@ -79,7 +88,6 @@ public class ReactiveLoadBalancerClientFilter implements GlobalFilter, Ordered {
}
@Override
@SuppressWarnings("Duplicates")
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
URI url = exchange.getAttribute(GATEWAY_REQUEST_URL_ATTR);
String schemePrefix = exchange.getAttribute(GATEWAY_SCHEME_PREFIX_ATTR);
@ -93,9 +101,16 @@ public class ReactiveLoadBalancerClientFilter implements GlobalFilter, Ordered { @@ -93,9 +101,16 @@ public class ReactiveLoadBalancerClientFilter implements GlobalFilter, Ordered {
log.trace(ReactiveLoadBalancerClientFilter.class.getSimpleName() + " url before: " + url);
}
return choose(exchange).doOnNext(response -> {
URI requestUri = exchange.getAttribute(GATEWAY_REQUEST_URL_ATTR);
String serviceId = requestUri.getHost();
Set<LoadBalancerLifecycle> supportedLifecycleProcessors = LoadBalancerLifecycleValidator
.getSupportedLifecycleProcessors(clientFactory.getInstances(serviceId, LoadBalancerLifecycle.class),
ServerHttpRequest.class, ServerHttpResponse.class, ServiceInstance.class);
return choose(exchange, serviceId, supportedLifecycleProcessors).doOnNext(response -> {
if (!response.hasServer()) {
supportedLifecycleProcessors.forEach(lifecycle -> lifecycle
.onComplete(new CompletionContext<>(CompletionContext.Status.DISCARD, response)));
throw NotFoundException.create(properties.isUse404(), "Unable to find instance for " + url.getHost());
}
@ -119,23 +134,30 @@ public class ReactiveLoadBalancerClientFilter implements GlobalFilter, Ordered { @@ -119,23 +134,30 @@ public class ReactiveLoadBalancerClientFilter implements GlobalFilter, Ordered {
log.trace("LoadBalancerClientFilter url chosen: " + requestUrl);
}
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, requestUrl);
}).then(chain.filter(exchange));
exchange.getAttributes().put(GATEWAY_LOADBALANCER_RESPONSE_ATTR, response);
}).then(chain.filter(exchange))
.doOnError(throwable -> supportedLifecycleProcessors.forEach(lifecycle -> lifecycle.onComplete(
new CompletionContext<ClientResponse, ServiceInstance>(CompletionContext.Status.FAILED,
throwable, exchange.getAttribute(GATEWAY_LOADBALANCER_RESPONSE_ATTR)))))
.doOnSuccess(aVoid -> supportedLifecycleProcessors.forEach(
lifecycle -> lifecycle.onComplete(new CompletionContext<>(CompletionContext.Status.SUCCESS,
exchange.getAttribute(GATEWAY_LOADBALANCER_RESPONSE_ATTR), exchange.getResponse()))));
}
protected URI reconstructURI(ServiceInstance serviceInstance, URI original) {
return LoadBalancerUriTools.reconstructURI(serviceInstance, original);
}
private Mono<Response<ServiceInstance>> choose(ServerWebExchange exchange) {
URI uri = exchange.getAttribute(GATEWAY_REQUEST_URL_ATTR);
String serviceId = uri.getHost();
private Mono<Response<ServiceInstance>> choose(ServerWebExchange exchange, String serviceId,
Set<LoadBalancerLifecycle> supportedLifecycleProcessors) {
ReactorLoadBalancer<ServiceInstance> loadBalancer = this.clientFactory.getInstance(serviceId,
ReactorServiceInstanceLoadBalancer.class);
if (loadBalancer == null) {
throw new NotFoundException("No loadbalancer available for " + serviceId);
}
DefaultRequest<ServerHttpRequestContext> lbRequest = new DefaultRequest<>(new ServerHttpRequestContext(
exchange.getRequest(), getHint(serviceId, loadBalancerProperties.getHint())));
DefaultRequest<RequestDataContext> lbRequest = new DefaultRequest<>(new RequestDataContext(
new RequestData(exchange.getRequest()), getHint(serviceId, loadBalancerProperties.getHint())));
supportedLifecycleProcessors.forEach(lifecycle -> lifecycle.onStart(lbRequest));
return loadBalancer.choose(lbRequest);
}

12
spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/support/ServerWebExchangeUtils.java

@ -31,6 +31,7 @@ import org.apache.commons.logging.LogFactory; @@ -31,6 +31,7 @@ import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.cloud.client.loadbalancer.Response;
import org.springframework.cloud.gateway.filter.factory.GatewayFilterFactory;
import org.springframework.cloud.gateway.handler.AsyncPredicate;
import org.springframework.cloud.gateway.handler.predicate.RoutePredicateFactory;
@ -148,6 +149,11 @@ public final class ServerWebExchangeUtils { @@ -148,6 +149,11 @@ public final class ServerWebExchangeUtils {
*/
public static final String CACHED_REQUEST_BODY_ATTR = "cachedRequestBody";
/**
* Gateway LoadBalancer {@link Response} attribute name.
*/
public static final String GATEWAY_LOADBALANCER_RESPONSE_ATTR = qualify("gatewayLoadBalancerResponse");
private ServerWebExchangeUtils() {
throw new AssertionError("Must not instantiate utility class.");
}
@ -194,7 +200,7 @@ public final class ServerWebExchangeUtils { @@ -194,7 +200,7 @@ public final class ServerWebExchangeUtils {
return setResponseStatus(exchange, statusHolder.getHttpStatus());
}
if (statusHolder.getStatus() != null && exchange.getResponse() instanceof AbstractServerHttpResponse) { // non-standard
((AbstractServerHttpResponse) exchange.getResponse()).setStatusCodeValue(statusHolder.getStatus());
((AbstractServerHttpResponse) exchange.getResponse()).setRawStatusCode(statusHolder.getStatus());
return true;
}
return false;
@ -210,9 +216,9 @@ public final class ServerWebExchangeUtils { @@ -210,9 +216,9 @@ public final class ServerWebExchangeUtils {
UriComponentsBuilder.fromUri(uri).build(true);
return true;
}
catch (IllegalArgumentException ignore) {
catch (IllegalArgumentException ignored) {
if (log.isTraceEnabled()) {
log.trace("Error in containsEncodedParts", ignore);
log.trace("Error in containsEncodedParts", ignored);
}
}

147
spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/ReactiveLoadBalancerClientFilterTests.java

@ -22,20 +22,24 @@ import java.util.LinkedHashSet; @@ -22,20 +22,24 @@ import java.util.LinkedHashSet;
import java.util.Map;
import org.jetbrains.annotations.NotNull;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Mono;
import org.springframework.cloud.client.DefaultServiceInstance;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.CompletionContext;
import org.springframework.cloud.client.loadbalancer.DefaultResponse;
import org.springframework.cloud.client.loadbalancer.EmptyResponse;
import org.springframework.cloud.client.loadbalancer.LoadBalancerLifecycle;
import org.springframework.cloud.client.loadbalancer.Request;
import org.springframework.cloud.client.loadbalancer.ServerHttpRequestContext;
import org.springframework.cloud.client.loadbalancer.RequestDataContext;
import org.springframework.cloud.client.loadbalancer.Response;
import org.springframework.cloud.client.loadbalancer.reactive.LoadBalancerProperties;
import org.springframework.cloud.gateway.config.GatewayLoadBalancerProperties;
import org.springframework.cloud.gateway.support.NotFoundException;
@ -51,6 +55,7 @@ import org.springframework.web.server.ServerWebExchange; @@ -51,6 +55,7 @@ import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.mock;
@ -58,6 +63,7 @@ import static org.mockito.Mockito.verify; @@ -58,6 +63,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_LOADBALANCER_RESPONSE_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_ORIGINAL_REQUEST_URL_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_SCHEME_PREFIX_ATTR;
@ -70,8 +76,8 @@ import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.G @@ -70,8 +76,8 @@ import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.G
* @author Olga Maciaszek-Sharma
*/
@SuppressWarnings("UnassignedFluxMonoInstance")
@RunWith(MockitoJUnitRunner.class)
public class ReactiveLoadBalancerClientFilterTests {
@ExtendWith(MockitoExtension.class)
class ReactiveLoadBalancerClientFilterTests {
private ServerWebExchange exchange;
@ -89,14 +95,14 @@ public class ReactiveLoadBalancerClientFilterTests { @@ -89,14 +95,14 @@ public class ReactiveLoadBalancerClientFilterTests {
@InjectMocks
private ReactiveLoadBalancerClientFilter filter;
@Before
public void setup() {
@BeforeEach
void setup() {
properties = new GatewayLoadBalancerProperties();
exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/mypath").build());
}
@Test
public void shouldNotFilterWhenGatewayRequestUrlIsMissing() {
void shouldNotFilterWhenGatewayRequestUrlIsMissing() {
filter.filter(exchange, chain);
verify(chain).filter(exchange);
@ -105,7 +111,7 @@ public class ReactiveLoadBalancerClientFilterTests { @@ -105,7 +111,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
@Test
public void shouldNotFilterWhenGatewayRequestUrlSchemeIsNotLb() {
void shouldNotFilterWhenGatewayRequestUrlSchemeIsNotLb() {
URI uri = UriComponentsBuilder.fromUriString("http://myservice").build().toUri();
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, uri);
@ -116,17 +122,19 @@ public class ReactiveLoadBalancerClientFilterTests { @@ -116,17 +122,19 @@ public class ReactiveLoadBalancerClientFilterTests {
verifyNoInteractions(clientFactory);
}
@Test(expected = NotFoundException.class)
public void shouldThrowExceptionWhenNoServiceInstanceIsFound() {
URI uri = UriComponentsBuilder.fromUriString("lb://myservice").build().toUri();
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, uri);
@Test
void shouldThrowExceptionWhenNoServiceInstanceIsFound() {
assertThatExceptionOfType(NotFoundException.class).isThrownBy(() -> {
URI uri = UriComponentsBuilder.fromUriString("lb://myservice").build().toUri();
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, uri);
filter.filter(exchange, chain).block();
filter.filter(exchange, chain).block();
});
}
@SuppressWarnings("unchecked")
@Test
public void shouldFilter() {
void shouldFilter() {
URI url = UriComponentsBuilder.fromUriString("lb://myservice").build().toUri();
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, url);
@ -145,6 +153,8 @@ public class ReactiveLoadBalancerClientFilterTests { @@ -145,6 +153,8 @@ public class ReactiveLoadBalancerClientFilterTests {
verify(clientFactory).getInstance("myservice", ReactorServiceInstanceLoadBalancer.class);
verify(clientFactory).getInstances("myservice", LoadBalancerLifecycle.class);
verifyNoMoreInteractions(clientFactory);
assertThat((URI) exchange.getAttribute(GATEWAY_REQUEST_URL_ATTR))
@ -155,7 +165,7 @@ public class ReactiveLoadBalancerClientFilterTests { @@ -155,7 +165,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
@Test
public void happyPath() {
void happyPath() {
MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost/get?a=b").build();
URI lbUri = URI.create("lb://service1?a=b");
@ -165,7 +175,7 @@ public class ReactiveLoadBalancerClientFilterTests { @@ -165,7 +175,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
@Test
public void noQueryParams() {
void noQueryParams() {
MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost/get").build();
ServerWebExchange webExchange = testFilter(request, URI.create("lb://service1"));
@ -174,7 +184,7 @@ public class ReactiveLoadBalancerClientFilterTests { @@ -174,7 +184,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
@Test
public void encodedParameters() {
void encodedParameters() {
URI url = UriComponentsBuilder.fromUriString("http://localhost/get?a=b&c=d[]").buildAndExpand().encode()
.toUri();
@ -196,7 +206,7 @@ public class ReactiveLoadBalancerClientFilterTests { @@ -196,7 +206,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
@Test
public void unencodedParameters() {
void unencodedParameters() {
URI url = URI.create("http://localhost/get?a=b&c=d[]");
MockServerHttpRequest request = MockServerHttpRequest.method(HttpMethod.GET, url).build();
@ -216,7 +226,7 @@ public class ReactiveLoadBalancerClientFilterTests { @@ -216,7 +226,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
@Test
public void happyPathWithAttributeRatherThanScheme() {
void happyPathWithAttributeRatherThanScheme() {
MockServerHttpRequest request = MockServerHttpRequest.get("ws://localhost/get?a=b").build();
URI lbUri = URI.create("ws://service1?a=b");
@ -230,7 +240,7 @@ public class ReactiveLoadBalancerClientFilterTests { @@ -230,7 +240,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
@Test
public void shouldNotFilterWhenGatewaySchemePrefixAttrIsNotLb() {
void shouldNotFilterWhenGatewaySchemePrefixAttrIsNotLb() {
URI uri = UriComponentsBuilder.fromUriString("http://myservice").build().toUri();
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, uri);
exchange.getAttributes().put(GATEWAY_SCHEME_PREFIX_ATTR, "xx");
@ -243,7 +253,7 @@ public class ReactiveLoadBalancerClientFilterTests { @@ -243,7 +253,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
@Test
public void shouldThrow4O4ExceptionWhenNoServiceInstanceIsFound() {
void shouldThrow4O4ExceptionWhenNoServiceInstanceIsFound() {
URI uri = UriComponentsBuilder.fromUriString("lb://service1").build().toUri();
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, uri);
RoundRobinLoadBalancer loadBalancer = new RoundRobinLoadBalancer(
@ -263,7 +273,7 @@ public class ReactiveLoadBalancerClientFilterTests { @@ -263,7 +273,7 @@ public class ReactiveLoadBalancerClientFilterTests {
@SuppressWarnings("unchecked")
@Test
public void shouldOverrideSchemeUsingIsSecure() {
void shouldOverrideSchemeUsingIsSecure() {
URI url = UriComponentsBuilder.fromUriString("lb://myservice").build().toUri();
ServerWebExchange exchange = MockServerWebExchange
.from(MockServerHttpRequest.get("https://localhost:9999/mypath").build());
@ -286,7 +296,7 @@ public class ReactiveLoadBalancerClientFilterTests { @@ -286,7 +296,7 @@ public class ReactiveLoadBalancerClientFilterTests {
@SuppressWarnings({ "rawtypes" })
@Test
public void shouldPassRequestToLoadBalancer() {
void shouldPassRequestToLoadBalancer() {
String hint = "test";
when(loadBalancerProperties.getHint()).thenReturn(buildHints(hint));
MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost/get?a=b").build();
@ -298,17 +308,92 @@ public class ReactiveLoadBalancerClientFilterTests { @@ -298,17 +308,92 @@ public class ReactiveLoadBalancerClientFilterTests {
when(serverWebExchange.getRequest()).thenReturn(request);
RoundRobinLoadBalancer loadBalancer = mock(RoundRobinLoadBalancer.class);
when(loadBalancer.choose(any(Request.class))).thenReturn(Mono.just(
new DefaultResponse(new DefaultServiceInstance("myservice1", "myservice", "localhost", 8080, false))));
new DefaultResponse(new DefaultServiceInstance("myservice1", "service1", "localhost", 8080, false))));
when(clientFactory.getInstance("service1", ReactorServiceInstanceLoadBalancer.class)).thenReturn(loadBalancer);
when(chain.filter(any())).thenReturn(Mono.empty());
filter.filter(serverWebExchange, chain);
verify(loadBalancer)
.choose(argThat((Request passedRequest) -> ((ServerHttpRequestContext) passedRequest.getContext())
.getClientRequest().equals(request)
&& ((ServerHttpRequestContext) passedRequest.getContext()).getHint().equals(hint)));
verify(loadBalancer).choose(argThat((Request passedRequest) -> ((RequestDataContext) passedRequest.getContext())
.getClientRequest().getUrl().equals(request.getURI())
&& ((RequestDataContext) passedRequest.getContext()).getHint().equals(hint)));
}
@SuppressWarnings({ "unchecked", "rawtypes" })
@Test
void loadBalancerLifecycleCallbacksExecutedForSuccess() {
LoadBalancerLifecycle lifecycleProcessor = mock(LoadBalancerLifecycle.class);
ServiceInstance serviceInstance = new DefaultServiceInstance("myservice1", "myservice", "localhost", 8080,
false);
ServerWebExchange serverWebExchange = mockExchange(serviceInstance, lifecycleProcessor, false);
filter.filter(serverWebExchange, chain).subscribe();
verify(lifecycleProcessor).onStart(any(Request.class));
verify(lifecycleProcessor).onComplete(
argThat(completionContext -> CompletionContext.Status.SUCCESS.equals(completionContext.status())
&& completionContext.getLoadBalancerResponse().getServer().equals(serviceInstance)));
}
@SuppressWarnings({ "unchecked", "rawtypes" })
@Test
void loadBalancerLifecycleCallbacksExecutedForDiscard() {
LoadBalancerLifecycle lifecycleProcessor = mock(LoadBalancerLifecycle.class);
ServiceInstance serviceInstance = null;
ServerWebExchange serverWebExchange = mockExchange(serviceInstance, lifecycleProcessor, false);
filter.filter(serverWebExchange, chain).subscribe();
verify(lifecycleProcessor).onStart(any(Request.class));
verify(lifecycleProcessor).onComplete(
argThat(completionContext -> CompletionContext.Status.DISCARD.equals(completionContext.status())));
}
@SuppressWarnings({ "unchecked", "rawtypes" })
@Test
void loadBalancerLifecycleCallbacksExecutedForFailed() {
LoadBalancerLifecycle lifecycleProcessor = mock(LoadBalancerLifecycle.class);
ServiceInstance serviceInstance = null;
ServerWebExchange serverWebExchange = mockExchange(serviceInstance, lifecycleProcessor, true);
filter.filter(serverWebExchange, chain).subscribe();
verify(lifecycleProcessor).onStart(any(Request.class));
verify(lifecycleProcessor).onComplete(
argThat(completionContext -> CompletionContext.Status.FAILED.equals(completionContext.status())));
}
@SuppressWarnings({ "rawtypes", "unchecked" })
private ServerWebExchange mockExchange(ServiceInstance serviceInstance, LoadBalancerLifecycle lifecycleProcessor,
boolean shouldThrowException) {
Response response;
when(lifecycleProcessor.supports(any(Class.class), any(Class.class), any(Class.class))).thenReturn(true);
MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost/get?a=b").build();
URI lbUri = URI.create("lb://service1?a=b");
ServerWebExchange serverWebExchange = MockServerWebExchange.from(request);
if (serviceInstance == null) {
response = new EmptyResponse();
}
else {
response = new DefaultResponse(serviceInstance);
}
serverWebExchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, lbUri);
serverWebExchange.getAttributes().put(GATEWAY_ORIGINAL_REQUEST_URL_ATTR, new LinkedHashSet<>());
serverWebExchange.getAttributes().put(GATEWAY_LOADBALANCER_RESPONSE_ATTR, response);
RoundRobinLoadBalancer loadBalancer = mock(RoundRobinLoadBalancer.class);
when(loadBalancer.choose(any(Request.class))).thenReturn(Mono.just(response));
when(clientFactory.getInstance("service1", ReactorServiceInstanceLoadBalancer.class)).thenReturn(loadBalancer);
Map<String, LoadBalancerLifecycle> lifecycleProcessors = new HashMap<>();
lifecycleProcessors.put("service1", lifecycleProcessor);
when(clientFactory.getInstances("service1", LoadBalancerLifecycle.class)).thenReturn(lifecycleProcessors);
if (shouldThrowException) {
when(chain.filter(any())).thenReturn(Mono.error(new UnsupportedOperationException()));
}
else {
when(chain.filter(any())).thenReturn(Mono.empty());
}
return serverWebExchange;
}
@NotNull

Loading…
Cancel
Save