@ -22,20 +22,24 @@ import java.util.LinkedHashSet;
import java.util.Map ;
import java.util.Map ;
import org.jetbrains.annotations.NotNull ;
import org.jetbrains.annotations.NotNull ;
import org.junit.Before ;
import org.junit.jupiter.api. BeforeEach ;
import org.junit.Test ;
import org.junit.jupiter.api. Test ;
import org.junit.runner.Run With ;
import org.junit.jupiter.api.extension.Extend With ;
import org.mockito.ArgumentCaptor ;
import org.mockito.ArgumentCaptor ;
import org.mockito.InjectMocks ;
import org.mockito.InjectMocks ;
import org.mockito.Mock ;
import org.mockito.Mock ;
import org.mockito.junit.MockitoJUnitRunner ;
import org.mockito.junit.jupiter.MockitoExtension ;
import reactor.core.publisher.Mono ;
import reactor.core.publisher.Mono ;
import org.springframework.cloud.client.DefaultServiceInstance ;
import org.springframework.cloud.client.DefaultServiceInstance ;
import org.springframework.cloud.client.ServiceInstance ;
import org.springframework.cloud.client.ServiceInstance ;
import org.springframework.cloud.client.loadbalancer.CompletionContext ;
import org.springframework.cloud.client.loadbalancer.DefaultResponse ;
import org.springframework.cloud.client.loadbalancer.DefaultResponse ;
import org.springframework.cloud.client.loadbalancer.EmptyResponse ;
import org.springframework.cloud.client.loadbalancer.LoadBalancerLifecycle ;
import org.springframework.cloud.client.loadbalancer.Request ;
import org.springframework.cloud.client.loadbalancer.Request ;
import org.springframework.cloud.client.loadbalancer.ServerHttpRequestContext ;
import org.springframework.cloud.client.loadbalancer.RequestDataContext ;
import org.springframework.cloud.client.loadbalancer.Response ;
import org.springframework.cloud.client.loadbalancer.reactive.LoadBalancerProperties ;
import org.springframework.cloud.client.loadbalancer.reactive.LoadBalancerProperties ;
import org.springframework.cloud.gateway.config.GatewayLoadBalancerProperties ;
import org.springframework.cloud.gateway.config.GatewayLoadBalancerProperties ;
import org.springframework.cloud.gateway.support.NotFoundException ;
import org.springframework.cloud.gateway.support.NotFoundException ;
@ -51,6 +55,7 @@ import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder ;
import org.springframework.web.util.UriComponentsBuilder ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType ;
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.ArgumentMatchers.argThat ;
import static org.mockito.ArgumentMatchers.argThat ;
import static org.mockito.Mockito.mock ;
import static org.mockito.Mockito.mock ;
@ -58,6 +63,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions ;
import static org.mockito.Mockito.verifyNoInteractions ;
import static org.mockito.Mockito.verifyNoMoreInteractions ;
import static org.mockito.Mockito.verifyNoMoreInteractions ;
import static org.mockito.Mockito.when ;
import static org.mockito.Mockito.when ;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_LOADBALANCER_RESPONSE_ATTR ;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_ORIGINAL_REQUEST_URL_ATTR ;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_ORIGINAL_REQUEST_URL_ATTR ;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR ;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR ;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_SCHEME_PREFIX_ATTR ;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_SCHEME_PREFIX_ATTR ;
@ -70,8 +76,8 @@ import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.G
* @author Olga Maciaszek - Sharma
* @author Olga Maciaszek - Sharma
* /
* /
@SuppressWarnings ( "UnassignedFluxMonoInstance" )
@SuppressWarnings ( "UnassignedFluxMonoInstance" )
@RunWith ( MockitoJUnitRunner . class )
@ExtendWith ( MockitoExtension . class )
public class ReactiveLoadBalancerClientFilterTests {
class ReactiveLoadBalancerClientFilterTests {
private ServerWebExchange exchange ;
private ServerWebExchange exchange ;
@ -89,14 +95,14 @@ public class ReactiveLoadBalancerClientFilterTests {
@InjectMocks
@InjectMocks
private ReactiveLoadBalancerClientFilter filter ;
private ReactiveLoadBalancerClientFilter filter ;
@Before
@BeforeEach
public void setup ( ) {
void setup ( ) {
properties = new GatewayLoadBalancerProperties ( ) ;
properties = new GatewayLoadBalancerProperties ( ) ;
exchange = MockServerWebExchange . from ( MockServerHttpRequest . get ( "/mypath" ) . build ( ) ) ;
exchange = MockServerWebExchange . from ( MockServerHttpRequest . get ( "/mypath" ) . build ( ) ) ;
}
}
@Test
@Test
public void shouldNotFilterWhenGatewayRequestUrlIsMissing ( ) {
void shouldNotFilterWhenGatewayRequestUrlIsMissing ( ) {
filter . filter ( exchange , chain ) ;
filter . filter ( exchange , chain ) ;
verify ( chain ) . filter ( exchange ) ;
verify ( chain ) . filter ( exchange ) ;
@ -105,7 +111,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
}
@Test
@Test
public void shouldNotFilterWhenGatewayRequestUrlSchemeIsNotLb ( ) {
void shouldNotFilterWhenGatewayRequestUrlSchemeIsNotLb ( ) {
URI uri = UriComponentsBuilder . fromUriString ( "http://myservice" ) . build ( ) . toUri ( ) ;
URI uri = UriComponentsBuilder . fromUriString ( "http://myservice" ) . build ( ) . toUri ( ) ;
exchange . getAttributes ( ) . put ( GATEWAY_REQUEST_URL_ATTR , uri ) ;
exchange . getAttributes ( ) . put ( GATEWAY_REQUEST_URL_ATTR , uri ) ;
@ -116,17 +122,19 @@ public class ReactiveLoadBalancerClientFilterTests {
verifyNoInteractions ( clientFactory ) ;
verifyNoInteractions ( clientFactory ) ;
}
}
@Test ( expected = NotFoundException . class )
@Test
public void shouldThrowExceptionWhenNoServiceInstanceIsFound ( ) {
void shouldThrowExceptionWhenNoServiceInstanceIsFound ( ) {
URI uri = UriComponentsBuilder . fromUriString ( "lb://myservice" ) . build ( ) . toUri ( ) ;
assertThatExceptionOfType ( NotFoundException . class ) . isThrownBy ( ( ) - > {
exchange . getAttributes ( ) . put ( GATEWAY_REQUEST_URL_ATTR , uri ) ;
URI uri = UriComponentsBuilder . fromUriString ( "lb://myservice" ) . build ( ) . toUri ( ) ;
exchange . getAttributes ( ) . put ( GATEWAY_REQUEST_URL_ATTR , uri ) ;
filter . filter ( exchange , chain ) . block ( ) ;
filter . filter ( exchange , chain ) . block ( ) ;
} ) ;
}
}
@SuppressWarnings ( "unchecked" )
@SuppressWarnings ( "unchecked" )
@Test
@Test
public void shouldFilter ( ) {
void shouldFilter ( ) {
URI url = UriComponentsBuilder . fromUriString ( "lb://myservice" ) . build ( ) . toUri ( ) ;
URI url = UriComponentsBuilder . fromUriString ( "lb://myservice" ) . build ( ) . toUri ( ) ;
exchange . getAttributes ( ) . put ( GATEWAY_REQUEST_URL_ATTR , url ) ;
exchange . getAttributes ( ) . put ( GATEWAY_REQUEST_URL_ATTR , url ) ;
@ -145,6 +153,8 @@ public class ReactiveLoadBalancerClientFilterTests {
verify ( clientFactory ) . getInstance ( "myservice" , ReactorServiceInstanceLoadBalancer . class ) ;
verify ( clientFactory ) . getInstance ( "myservice" , ReactorServiceInstanceLoadBalancer . class ) ;
verify ( clientFactory ) . getInstances ( "myservice" , LoadBalancerLifecycle . class ) ;
verifyNoMoreInteractions ( clientFactory ) ;
verifyNoMoreInteractions ( clientFactory ) ;
assertThat ( ( URI ) exchange . getAttribute ( GATEWAY_REQUEST_URL_ATTR ) )
assertThat ( ( URI ) exchange . getAttribute ( GATEWAY_REQUEST_URL_ATTR ) )
@ -155,7 +165,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
}
@Test
@Test
public void happyPath ( ) {
void happyPath ( ) {
MockServerHttpRequest request = MockServerHttpRequest . get ( "http://localhost/get?a=b" ) . build ( ) ;
MockServerHttpRequest request = MockServerHttpRequest . get ( "http://localhost/get?a=b" ) . build ( ) ;
URI lbUri = URI . create ( "lb://service1?a=b" ) ;
URI lbUri = URI . create ( "lb://service1?a=b" ) ;
@ -165,7 +175,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
}
@Test
@Test
public void noQueryParams ( ) {
void noQueryParams ( ) {
MockServerHttpRequest request = MockServerHttpRequest . get ( "http://localhost/get" ) . build ( ) ;
MockServerHttpRequest request = MockServerHttpRequest . get ( "http://localhost/get" ) . build ( ) ;
ServerWebExchange webExchange = testFilter ( request , URI . create ( "lb://service1" ) ) ;
ServerWebExchange webExchange = testFilter ( request , URI . create ( "lb://service1" ) ) ;
@ -174,7 +184,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
}
@Test
@Test
public void encodedParameters ( ) {
void encodedParameters ( ) {
URI url = UriComponentsBuilder . fromUriString ( "http://localhost/get?a=b&c=d[]" ) . buildAndExpand ( ) . encode ( )
URI url = UriComponentsBuilder . fromUriString ( "http://localhost/get?a=b&c=d[]" ) . buildAndExpand ( ) . encode ( )
. toUri ( ) ;
. toUri ( ) ;
@ -196,7 +206,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
}
@Test
@Test
public void unencodedParameters ( ) {
void unencodedParameters ( ) {
URI url = URI . create ( "http://localhost/get?a=b&c=d[]" ) ;
URI url = URI . create ( "http://localhost/get?a=b&c=d[]" ) ;
MockServerHttpRequest request = MockServerHttpRequest . method ( HttpMethod . GET , url ) . build ( ) ;
MockServerHttpRequest request = MockServerHttpRequest . method ( HttpMethod . GET , url ) . build ( ) ;
@ -216,7 +226,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
}
@Test
@Test
public void happyPathWithAttributeRatherThanScheme ( ) {
void happyPathWithAttributeRatherThanScheme ( ) {
MockServerHttpRequest request = MockServerHttpRequest . get ( "ws://localhost/get?a=b" ) . build ( ) ;
MockServerHttpRequest request = MockServerHttpRequest . get ( "ws://localhost/get?a=b" ) . build ( ) ;
URI lbUri = URI . create ( "ws://service1?a=b" ) ;
URI lbUri = URI . create ( "ws://service1?a=b" ) ;
@ -230,7 +240,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
}
@Test
@Test
public void shouldNotFilterWhenGatewaySchemePrefixAttrIsNotLb ( ) {
void shouldNotFilterWhenGatewaySchemePrefixAttrIsNotLb ( ) {
URI uri = UriComponentsBuilder . fromUriString ( "http://myservice" ) . build ( ) . toUri ( ) ;
URI uri = UriComponentsBuilder . fromUriString ( "http://myservice" ) . build ( ) . toUri ( ) ;
exchange . getAttributes ( ) . put ( GATEWAY_REQUEST_URL_ATTR , uri ) ;
exchange . getAttributes ( ) . put ( GATEWAY_REQUEST_URL_ATTR , uri ) ;
exchange . getAttributes ( ) . put ( GATEWAY_SCHEME_PREFIX_ATTR , "xx" ) ;
exchange . getAttributes ( ) . put ( GATEWAY_SCHEME_PREFIX_ATTR , "xx" ) ;
@ -243,7 +253,7 @@ public class ReactiveLoadBalancerClientFilterTests {
}
}
@Test
@Test
public void shouldThrow4O4ExceptionWhenNoServiceInstanceIsFound ( ) {
void shouldThrow4O4ExceptionWhenNoServiceInstanceIsFound ( ) {
URI uri = UriComponentsBuilder . fromUriString ( "lb://service1" ) . build ( ) . toUri ( ) ;
URI uri = UriComponentsBuilder . fromUriString ( "lb://service1" ) . build ( ) . toUri ( ) ;
exchange . getAttributes ( ) . put ( GATEWAY_REQUEST_URL_ATTR , uri ) ;
exchange . getAttributes ( ) . put ( GATEWAY_REQUEST_URL_ATTR , uri ) ;
RoundRobinLoadBalancer loadBalancer = new RoundRobinLoadBalancer (
RoundRobinLoadBalancer loadBalancer = new RoundRobinLoadBalancer (
@ -263,7 +273,7 @@ public class ReactiveLoadBalancerClientFilterTests {
@SuppressWarnings ( "unchecked" )
@SuppressWarnings ( "unchecked" )
@Test
@Test
public void shouldOverrideSchemeUsingIsSecure ( ) {
void shouldOverrideSchemeUsingIsSecure ( ) {
URI url = UriComponentsBuilder . fromUriString ( "lb://myservice" ) . build ( ) . toUri ( ) ;
URI url = UriComponentsBuilder . fromUriString ( "lb://myservice" ) . build ( ) . toUri ( ) ;
ServerWebExchange exchange = MockServerWebExchange
ServerWebExchange exchange = MockServerWebExchange
. from ( MockServerHttpRequest . get ( "https://localhost:9999/mypath" ) . build ( ) ) ;
. from ( MockServerHttpRequest . get ( "https://localhost:9999/mypath" ) . build ( ) ) ;
@ -286,7 +296,7 @@ public class ReactiveLoadBalancerClientFilterTests {
@SuppressWarnings ( { "rawtypes" } )
@SuppressWarnings ( { "rawtypes" } )
@Test
@Test
public void shouldPassRequestToLoadBalancer ( ) {
void shouldPassRequestToLoadBalancer ( ) {
String hint = "test" ;
String hint = "test" ;
when ( loadBalancerProperties . getHint ( ) ) . thenReturn ( buildHints ( hint ) ) ;
when ( loadBalancerProperties . getHint ( ) ) . thenReturn ( buildHints ( hint ) ) ;
MockServerHttpRequest request = MockServerHttpRequest . get ( "http://localhost/get?a=b" ) . build ( ) ;
MockServerHttpRequest request = MockServerHttpRequest . get ( "http://localhost/get?a=b" ) . build ( ) ;
@ -298,17 +308,92 @@ public class ReactiveLoadBalancerClientFilterTests {
when ( serverWebExchange . getRequest ( ) ) . thenReturn ( request ) ;
when ( serverWebExchange . getRequest ( ) ) . thenReturn ( request ) ;
RoundRobinLoadBalancer loadBalancer = mock ( RoundRobinLoadBalancer . class ) ;
RoundRobinLoadBalancer loadBalancer = mock ( RoundRobinLoadBalancer . class ) ;
when ( loadBalancer . choose ( any ( Request . class ) ) ) . thenReturn ( Mono . just (
when ( loadBalancer . choose ( any ( Request . class ) ) ) . thenReturn ( Mono . just (
new DefaultResponse ( new DefaultServiceInstance ( "myservice1" , "my service" , "localhost" , 8080 , false ) ) ) ) ;
new DefaultResponse ( new DefaultServiceInstance ( "myservice1" , "service1 " , "localhost" , 8080 , false ) ) ) ) ;
when ( clientFactory . getInstance ( "service1" , ReactorServiceInstanceLoadBalancer . class ) ) . thenReturn ( loadBalancer ) ;
when ( clientFactory . getInstance ( "service1" , ReactorServiceInstanceLoadBalancer . class ) ) . thenReturn ( loadBalancer ) ;
when ( chain . filter ( any ( ) ) ) . thenReturn ( Mono . empty ( ) ) ;
when ( chain . filter ( any ( ) ) ) . thenReturn ( Mono . empty ( ) ) ;
filter . filter ( serverWebExchange , chain ) ;
filter . filter ( serverWebExchange , chain ) ;
verify ( loadBalancer )
verify ( loadBalancer ) . choose ( argThat ( ( Request passedRequest ) - > ( ( RequestDataContext ) passedRequest . getContext ( ) )
. choose ( argThat ( ( Request passedRequest ) - > ( ( ServerHttpRequestContext ) passedRequest . getContext ( ) )
. getClientRequest ( ) . getUrl ( ) . equals ( request . getURI ( ) )
. getClientRequest ( ) . equals ( request )
& & ( ( RequestDataContext ) passedRequest . getContext ( ) ) . getHint ( ) . equals ( hint ) ) ) ;
& & ( ( ServerHttpRequestContext ) passedRequest . getContext ( ) ) . getHint ( ) . equals ( hint ) ) ) ;
}
@SuppressWarnings ( { "unchecked" , "rawtypes" } )
@Test
void loadBalancerLifecycleCallbacksExecutedForSuccess ( ) {
LoadBalancerLifecycle lifecycleProcessor = mock ( LoadBalancerLifecycle . class ) ;
ServiceInstance serviceInstance = new DefaultServiceInstance ( "myservice1" , "myservice" , "localhost" , 8080 ,
false ) ;
ServerWebExchange serverWebExchange = mockExchange ( serviceInstance , lifecycleProcessor , false ) ;
filter . filter ( serverWebExchange , chain ) . subscribe ( ) ;
verify ( lifecycleProcessor ) . onStart ( any ( Request . class ) ) ;
verify ( lifecycleProcessor ) . onComplete (
argThat ( completionContext - > CompletionContext . Status . SUCCESS . equals ( completionContext . status ( ) )
& & completionContext . getLoadBalancerResponse ( ) . getServer ( ) . equals ( serviceInstance ) ) ) ;
}
@SuppressWarnings ( { "unchecked" , "rawtypes" } )
@Test
void loadBalancerLifecycleCallbacksExecutedForDiscard ( ) {
LoadBalancerLifecycle lifecycleProcessor = mock ( LoadBalancerLifecycle . class ) ;
ServiceInstance serviceInstance = null ;
ServerWebExchange serverWebExchange = mockExchange ( serviceInstance , lifecycleProcessor , false ) ;
filter . filter ( serverWebExchange , chain ) . subscribe ( ) ;
verify ( lifecycleProcessor ) . onStart ( any ( Request . class ) ) ;
verify ( lifecycleProcessor ) . onComplete (
argThat ( completionContext - > CompletionContext . Status . DISCARD . equals ( completionContext . status ( ) ) ) ) ;
}
@SuppressWarnings ( { "unchecked" , "rawtypes" } )
@Test
void loadBalancerLifecycleCallbacksExecutedForFailed ( ) {
LoadBalancerLifecycle lifecycleProcessor = mock ( LoadBalancerLifecycle . class ) ;
ServiceInstance serviceInstance = null ;
ServerWebExchange serverWebExchange = mockExchange ( serviceInstance , lifecycleProcessor , true ) ;
filter . filter ( serverWebExchange , chain ) . subscribe ( ) ;
verify ( lifecycleProcessor ) . onStart ( any ( Request . class ) ) ;
verify ( lifecycleProcessor ) . onComplete (
argThat ( completionContext - > CompletionContext . Status . FAILED . equals ( completionContext . status ( ) ) ) ) ;
}
@SuppressWarnings ( { "rawtypes" , "unchecked" } )
private ServerWebExchange mockExchange ( ServiceInstance serviceInstance , LoadBalancerLifecycle lifecycleProcessor ,
boolean shouldThrowException ) {
Response response ;
when ( lifecycleProcessor . supports ( any ( Class . class ) , any ( Class . class ) , any ( Class . class ) ) ) . thenReturn ( true ) ;
MockServerHttpRequest request = MockServerHttpRequest . get ( "http://localhost/get?a=b" ) . build ( ) ;
URI lbUri = URI . create ( "lb://service1?a=b" ) ;
ServerWebExchange serverWebExchange = MockServerWebExchange . from ( request ) ;
if ( serviceInstance = = null ) {
response = new EmptyResponse ( ) ;
}
else {
response = new DefaultResponse ( serviceInstance ) ;
}
serverWebExchange . getAttributes ( ) . put ( GATEWAY_REQUEST_URL_ATTR , lbUri ) ;
serverWebExchange . getAttributes ( ) . put ( GATEWAY_ORIGINAL_REQUEST_URL_ATTR , new LinkedHashSet < > ( ) ) ;
serverWebExchange . getAttributes ( ) . put ( GATEWAY_LOADBALANCER_RESPONSE_ATTR , response ) ;
RoundRobinLoadBalancer loadBalancer = mock ( RoundRobinLoadBalancer . class ) ;
when ( loadBalancer . choose ( any ( Request . class ) ) ) . thenReturn ( Mono . just ( response ) ) ;
when ( clientFactory . getInstance ( "service1" , ReactorServiceInstanceLoadBalancer . class ) ) . thenReturn ( loadBalancer ) ;
Map < String , LoadBalancerLifecycle > lifecycleProcessors = new HashMap < > ( ) ;
lifecycleProcessors . put ( "service1" , lifecycleProcessor ) ;
when ( clientFactory . getInstances ( "service1" , LoadBalancerLifecycle . class ) ) . thenReturn ( lifecycleProcessors ) ;
if ( shouldThrowException ) {
when ( chain . filter ( any ( ) ) ) . thenReturn ( Mono . error ( new UnsupportedOperationException ( ) ) ) ;
}
else {
when ( chain . filter ( any ( ) ) ) . thenReturn ( Mono . empty ( ) ) ;
}
return serverWebExchange ;
}
}
@NotNull
@NotNull