From c883495f8ac5009e1ca7b266b17af550e4cea404 Mon Sep 17 00:00:00 2001 From: Stephen Oakey Date: Fri, 10 Jul 2015 22:02:24 -0400 Subject: [PATCH] Detect https from Ribbon Client Config RibbonLoadBalancer and RibbonLoadBalancer use IClientConfig IsSecure property for service ID to determine if the service is secure. Replaces scheme of URI of request with https if the request is secure. fixes gh-337 --- .../feign/ribbon/RibbonLoadBalancer.java | 14 ++ .../ribbon/RibbonLoadBalancerClient.java | 31 ++++- .../feign/ribbon/RibbonLoadBalancerTests.java | 130 ++++++++++++++++++ .../ribbon/RibbonLoadBalancerClientTests.java | 49 +++++-- 4 files changed, 206 insertions(+), 18 deletions(-) create mode 100644 spring-cloud-netflix-core/src/test/java/org/springframework/cloud/netflix/feign/ribbon/RibbonLoadBalancerTests.java diff --git a/spring-cloud-netflix-core/src/main/java/org/springframework/cloud/netflix/feign/ribbon/RibbonLoadBalancer.java b/spring-cloud-netflix-core/src/main/java/org/springframework/cloud/netflix/feign/ribbon/RibbonLoadBalancer.java index 52562872..2e73a2b4 100644 --- a/spring-cloud-netflix-core/src/main/java/org/springframework/cloud/netflix/feign/ribbon/RibbonLoadBalancer.java +++ b/spring-cloud-netflix-core/src/main/java/org/springframework/cloud/netflix/feign/ribbon/RibbonLoadBalancer.java @@ -21,6 +21,8 @@ import java.net.URI; import java.util.Collection; import java.util.Map; +import org.springframework.web.util.UriComponentsBuilder; + import com.netflix.client.AbstractLoadBalancerAwareClient; import com.netflix.client.ClientException; import com.netflix.client.ClientRequest; @@ -48,11 +50,14 @@ public class RibbonLoadBalancer private final IClientConfig clientConfig; + private final boolean secure; + public RibbonLoadBalancer(Client delegate, ILoadBalancer lb, IClientConfig clientConfig) { super(lb, clientConfig); this.setRetryHandler(RetryHandler.DEFAULT); this.clientConfig = clientConfig; + this.secure = clientConfig.get(CommonClientConfigKey.IsSecure); this.delegate = delegate; this.connectTimeout = clientConfig.get(CommonClientConfigKey.ConnectTimeout); this.readTimeout = clientConfig.get(CommonClientConfigKey.ReadTimeout); @@ -71,10 +76,19 @@ public class RibbonLoadBalancer else { options = new Request.Options(this.connectTimeout, this.readTimeout); } + if (isSecure(configOverride)) { + URI secureUri = UriComponentsBuilder.fromUri(request.getUri()) + .scheme("https").build().toUri(); + request = new RibbonRequest(request.toRequest(), secureUri); + } Response response = this.delegate.execute(request.toRequest(), options); return new RibbonResponse(request.getUri(), response); } + private boolean isSecure(IClientConfig config) { + return (config != null) ? config.get(CommonClientConfigKey.IsSecure) : secure; + } + @Override public RequestSpecificRetryHandler getRequestSpecificRetryHandler( RibbonRequest request, IClientConfig requestConfig) { diff --git a/spring-cloud-netflix-core/src/main/java/org/springframework/cloud/netflix/ribbon/RibbonLoadBalancerClient.java b/spring-cloud-netflix-core/src/main/java/org/springframework/cloud/netflix/ribbon/RibbonLoadBalancerClient.java index ac1b36f8..1bedb94e 100644 --- a/spring-cloud-netflix-core/src/main/java/org/springframework/cloud/netflix/ribbon/RibbonLoadBalancerClient.java +++ b/spring-cloud-netflix-core/src/main/java/org/springframework/cloud/netflix/ribbon/RibbonLoadBalancerClient.java @@ -25,7 +25,10 @@ import org.springframework.cloud.client.loadbalancer.LoadBalancerClient; import org.springframework.cloud.client.loadbalancer.LoadBalancerRequest; import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; +import org.springframework.web.util.UriComponentsBuilder; +import com.netflix.client.config.CommonClientConfigKey; +import com.netflix.client.config.IClientConfig; import com.netflix.loadbalancer.ILoadBalancer; import com.netflix.loadbalancer.Server; import com.netflix.loadbalancer.ServerStats; @@ -50,7 +53,12 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient { RibbonLoadBalancerContext context = this.clientFactory .getLoadBalancerContext(serviceId); Server server = new Server(instance.getHost(), instance.getPort()); - return context.reconstructURIWithServer(server, original); + boolean secure = isSecure(this.clientFactory, serviceId); + URI uri = original; + if(secure) { + uri = UriComponentsBuilder.fromUri(uri).scheme("https").build().toUri(); + } + return context.reconstructURIWithServer(server, uri); } @Override @@ -59,7 +67,7 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient { if (server == null) { return null; } - return new RibbonServer(serviceId, server); + return new RibbonServer(serviceId, server, isSecure(this.clientFactory, serviceId)); } @Override @@ -68,7 +76,7 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient { RibbonLoadBalancerContext context = this.clientFactory .getLoadBalancerContext(serviceId); Server server = getServer(loadBalancer); - RibbonServer ribbonServer = new RibbonServer(serviceId, server); + RibbonServer ribbonServer = new RibbonServer(serviceId, server, isSecure(clientFactory, serviceId)); ServerStats serverStats = context.getServerStats(server); context.noteOpenConnection(serverStats); @@ -85,6 +93,14 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient { } return null; } + + private boolean isSecure(SpringClientFactory clientFactory, String serviceId) { + IClientConfig config = clientFactory.getClientConfig(serviceId); + if(config != null) { + return config.get(CommonClientConfigKey.IsSecure, false); + } + return false; + } private void recordStats(RibbonLoadBalancerContext context, Stopwatch tracer, ServerStats serverStats, Object entity, Throwable exception) { @@ -111,8 +127,13 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient { protected static class RibbonServer implements ServiceInstance { private String serviceId; private Server server; - + private boolean secure; + protected RibbonServer(String serviceId, Server server) { + this(serviceId, server, false); + } + + protected RibbonServer(String serviceId, Server server, boolean secure) { this.serviceId = serviceId; this.server = server; } @@ -134,7 +155,7 @@ public class RibbonLoadBalancerClient implements LoadBalancerClient { @Override public boolean isSecure() { - return false; //TODO: howto determine https from ribbon Server + return this.secure; } @Override diff --git a/spring-cloud-netflix-core/src/test/java/org/springframework/cloud/netflix/feign/ribbon/RibbonLoadBalancerTests.java b/spring-cloud-netflix-core/src/test/java/org/springframework/cloud/netflix/feign/ribbon/RibbonLoadBalancerTests.java new file mode 100644 index 00000000..d4b1dc28 --- /dev/null +++ b/spring-cloud-netflix-core/src/test/java/org/springframework/cloud/netflix/feign/ribbon/RibbonLoadBalancerTests.java @@ -0,0 +1,130 @@ +package org.springframework.cloud.netflix.feign.ribbon; + +import static com.netflix.client.config.CommonClientConfigKey.ConnectTimeout; +import static com.netflix.client.config.CommonClientConfigKey.IsSecure; +import static com.netflix.client.config.CommonClientConfigKey.MaxAutoRetries; +import static com.netflix.client.config.CommonClientConfigKey.MaxAutoRetriesNextServer; +import static com.netflix.client.config.CommonClientConfigKey.OkToRetryOnAllOperations; +import static com.netflix.client.config.CommonClientConfigKey.ReadTimeout; +import static com.netflix.client.config.DefaultClientConfigImpl.DEFAULT_MAX_AUTO_RETRIES; +import static com.netflix.client.config.DefaultClientConfigImpl.DEFAULT_MAX_AUTO_RETRIES_NEXT_SERVER; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyBoolean; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.net.URI; +import java.util.Collection; +import java.util.Collections; + +import lombok.SneakyThrows; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.springframework.cloud.netflix.feign.ribbon.RibbonLoadBalancer.RibbonRequest; +import org.springframework.cloud.netflix.feign.ribbon.RibbonLoadBalancer.RibbonResponse; + +import com.netflix.client.config.IClientConfig; +import com.netflix.loadbalancer.ILoadBalancer; + +import feign.Client; +import feign.Request; +import feign.Request.Options; +import feign.RequestTemplate; +import feign.Response; + +public class RibbonLoadBalancerTests { + + @Mock + private Client delegate; + @Mock + private ILoadBalancer lb; + @Mock + private IClientConfig config; + + private RibbonLoadBalancer ribbonLoadBalancer; + + private Integer defaultConnectTimeout = 10000; + private Integer defaultReadTimeout = 10000; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + when(config.get(MaxAutoRetries, DEFAULT_MAX_AUTO_RETRIES)).thenReturn(1); + when(config.get(MaxAutoRetriesNextServer, DEFAULT_MAX_AUTO_RETRIES_NEXT_SERVER)) + .thenReturn(1); + when(config.get(OkToRetryOnAllOperations, eq(anyBoolean()))).thenReturn(true); + when(config.get(ConnectTimeout)).thenReturn(defaultConnectTimeout); + when(config.get(ReadTimeout)).thenReturn(defaultReadTimeout); + } + + @Test + @SneakyThrows + public void testUriInsecure() { + when(config.get(IsSecure)).thenReturn(false); + ribbonLoadBalancer = new RibbonLoadBalancer(delegate, lb, config); + Request request = new RequestTemplate().method("GET").append("http://foo/") + .request(); + RibbonRequest ribbonRequest = new RibbonRequest(request, new URI(request.url())); + + Response response = Response.create(200, "Test", + Collections.> emptyMap(), new byte[0]); + when(delegate.execute(any(Request.class), any(Options.class))).thenReturn( + response); + + RibbonResponse resp = ribbonLoadBalancer.execute(ribbonRequest, null); + + assertThat(resp.getRequestedURI(), is(new URI("http://foo/"))); + } + + @Test + @SneakyThrows + public void testSecureUriFromClientConfig() { + when(config.get(IsSecure)).thenReturn(true); + ribbonLoadBalancer = new RibbonLoadBalancer(delegate, lb, config); + Request request = new RequestTemplate().method("GET").append("http://foo/") + .request(); + RibbonRequest ribbonRequest = new RibbonRequest(request, new URI(request.url())); + + Response response = Response.create(200, "Test", + Collections.> emptyMap(), new byte[0]); + when(delegate.execute(any(Request.class), any(Options.class))).thenReturn( + response); + + RibbonResponse resp = ribbonLoadBalancer.execute(ribbonRequest, null); + + assertThat(resp.getRequestedURI(), is(new URI("https://foo/"))); + } + + @Test + @SneakyThrows + public void testSecureUriFromClientConfigOverride() { + when(config.get(IsSecure)).thenReturn(true); + ribbonLoadBalancer = new RibbonLoadBalancer(delegate, lb, config); + Request request = new RequestTemplate().method("GET").append("http://foo/") + .request(); + RibbonRequest ribbonRequest = new RibbonRequest(request, new URI(request.url())); + + Response response = Response.create(200, "Test", + Collections.> emptyMap(), new byte[0]); + when(delegate.execute(any(Request.class), any(Options.class))).thenReturn( + response); + + IClientConfig override = mock(IClientConfig.class); + when(override.get(ConnectTimeout, defaultConnectTimeout)).thenReturn(5000); + when(override.get(ReadTimeout, defaultConnectTimeout)).thenReturn(5000); + /* + * Override secure value. + */ + when(override.get(IsSecure)).thenReturn(false); + + RibbonResponse resp = ribbonLoadBalancer.execute(ribbonRequest, override); + + assertThat(resp.getRequestedURI(), is(new URI("http://foo/"))); + } +} diff --git a/spring-cloud-netflix-core/src/test/java/org/springframework/cloud/netflix/ribbon/RibbonLoadBalancerClientTests.java b/spring-cloud-netflix-core/src/test/java/org/springframework/cloud/netflix/ribbon/RibbonLoadBalancerClientTests.java index aff17a40..5c4c9eae 100644 --- a/spring-cloud-netflix-core/src/test/java/org/springframework/cloud/netflix/ribbon/RibbonLoadBalancerClientTests.java +++ b/spring-cloud-netflix-core/src/test/java/org/springframework/cloud/netflix/ribbon/RibbonLoadBalancerClientTests.java @@ -16,10 +16,23 @@ package org.springframework.cloud.netflix.ribbon; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; +import static org.mockito.BDDMockito.given; +import static org.mockito.Matchers.anyDouble; +import static org.mockito.Matchers.anyObject; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + import java.net.URI; import java.net.URL; import lombok.SneakyThrows; + import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -28,20 +41,13 @@ import org.springframework.cloud.client.ServiceInstance; import org.springframework.cloud.client.loadbalancer.LoadBalancerRequest; import org.springframework.cloud.netflix.ribbon.RibbonLoadBalancerClient.RibbonServer; +import com.netflix.client.config.CommonClientConfigKey; +import com.netflix.client.config.IClientConfig; import com.netflix.loadbalancer.BaseLoadBalancer; import com.netflix.loadbalancer.LoadBalancerStats; import com.netflix.loadbalancer.Server; import com.netflix.loadbalancer.ServerStats; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.fail; -import static org.junit.Assert.assertNull; -import static org.mockito.BDDMockito.given; -import static org.mockito.Matchers.anyDouble; -import static org.mockito.Matchers.anyString; -import static org.mockito.Mockito.verify; - /** * @author Spencer Gibb */ @@ -81,10 +87,27 @@ public class RibbonLoadBalancerClientTests { RibbonServer server = getRibbonServer(); RibbonLoadBalancerClient client = getRibbonLoadBalancerClient(server); ServiceInstance serviceInstance = client.choose(server.getServiceId()); - URI uri = client.reconstructURI(serviceInstance, new URL(scheme +"://" - + server.getServiceId()).toURI()); + URI uri = client.reconstructURI(serviceInstance, + new URL(scheme + "://" + server.getServiceId()).toURI()); + assertEquals(server.getHost(), uri.getHost()); + assertEquals(server.getPort(), uri.getPort()); + } + + @Test + @SneakyThrows + public void testReconstructUriWithSecureClientConfig() { + RibbonServer server = getRibbonServer(); + IClientConfig config = mock(IClientConfig.class); + when(config.get(CommonClientConfigKey.IsSecure, false)).thenReturn(true); + when(clientFactory.getClientConfig(server.getServiceId())).thenReturn(config); + + RibbonLoadBalancerClient client = getRibbonLoadBalancerClient(server); + ServiceInstance serviceInstance = client.choose(server.getServiceId()); + URI uri = client.reconstructURI(serviceInstance, + new URL("http://" + server.getServiceId()).toURI()); assertEquals(server.getHost(), uri.getHost()); assertEquals(server.getPort(), uri.getPort()); + assertEquals("https", uri.getScheme()); } @Test @@ -166,8 +189,8 @@ public class RibbonLoadBalancerClientTests { protected RibbonLoadBalancerClient getRibbonLoadBalancerClient( RibbonServer ribbonServer) { given(this.loadBalancer.getName()).willReturn(ribbonServer.getServiceId()); - given(this.loadBalancer.chooseServer(anyString())) - .willReturn(ribbonServer.getServer()); + given(this.loadBalancer.chooseServer(anyObject())).willReturn( + ribbonServer.getServer()); given(this.loadBalancer.getLoadBalancerStats()) .willReturn(this.loadBalancerStats); given(this.loadBalancerStats.getSingleServerStat(ribbonServer.getServer()))