From 706fbc0df2383d5e329e04c1133366098b715b0a Mon Sep 17 00:00:00 2001 From: jphilippeplante Date: Wed, 7 Feb 2018 14:13:57 -0500 Subject: [PATCH] Enhancement: adding IPv6 support to Remote Addr Predicate gh-165 (#183) * Enhancement: adding IPv6 support to RemoteAddrRoutePredicateFactory * Using IpSubnetFilterRule from Netty as suggested fixes gh-165 --- .../main/asciidoc/spring-cloud-gateway.adoc | 2 +- .../RemoteAddrRoutePredicateFactory.java | 33 +- .../cloud/gateway/support/SubnetUtils.java | 364 ------------------ .../RemoteAddrRoutePredicateFactoryTests.java | 53 +++ .../resources/application-remote-address.yml | 27 ++ 5 files changed, 97 insertions(+), 382 deletions(-) delete mode 100644 spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/support/SubnetUtils.java create mode 100644 spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/handler/predicate/RemoteAddrRoutePredicateFactoryTests.java create mode 100644 spring-cloud-gateway-core/src/test/resources/application-remote-address.yml diff --git a/docs/src/main/asciidoc/spring-cloud-gateway.adoc b/docs/src/main/asciidoc/spring-cloud-gateway.adoc index 2d8d9dbf2..197aff920 100644 --- a/docs/src/main/asciidoc/spring-cloud-gateway.adoc +++ b/docs/src/main/asciidoc/spring-cloud-gateway.adoc @@ -221,7 +221,7 @@ This route would match if the request contained a `foo` query parameter whose va === RemoteAddr Route Predicate Factory -The RemoteAddr Route Predicate Factory takes a list (min size 1) of CIDR-notation strings, e.g. `192.168.0.1/16` (where `192.168.0.1` is an IP address and `16` is a subnet mask. +The RemoteAddr Route Predicate Factory takes a list (min size 1) of CIDR-notation (IPv4 or IPv6) strings, e.g. `192.168.0.1/16` (where `192.168.0.1` is an IP address and `16` is a subnet mask. .application.yml [source,yaml] diff --git a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/handler/predicate/RemoteAddrRoutePredicateFactory.java b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/handler/predicate/RemoteAddrRoutePredicateFactory.java index b7706fc35..7d924f164 100644 --- a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/handler/predicate/RemoteAddrRoutePredicateFactory.java +++ b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/handler/predicate/RemoteAddrRoutePredicateFactory.java @@ -24,11 +24,13 @@ import java.util.function.Predicate; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.cloud.gateway.support.SubnetUtils; import org.springframework.tuple.Tuple; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; +import io.netty.handler.ipfilter.IpFilterRuleType; +import io.netty.handler.ipfilter.IpSubnetFilterRule; + /** * @author Spencer Gibb */ @@ -38,9 +40,9 @@ public class RemoteAddrRoutePredicateFactory implements RoutePredicateFactory { @Override public Predicate apply(Tuple args) { - validate(1, args); + validateMin(1, args); - List sources = new ArrayList<>(); + List sources = new ArrayList<>(); if (args != null) { for (Object arg : args.getValues()) { addSource(sources, (String) arg); @@ -52,14 +54,14 @@ public class RemoteAddrRoutePredicateFactory implements RoutePredicateFactory { public Predicate apply(String... addrs) { Assert.notEmpty(addrs, "addrs must not be empty"); - List sources = new ArrayList<>(); + List sources = new ArrayList<>(); for (String addr : addrs) { addSource(sources, addr); } return apply(sources); } - public Predicate apply(List sources) { + public Predicate apply(List sources) { return exchange -> { InetSocketAddress remoteAddress = exchange.getRequest().getRemoteAddress(); if (remoteAddress != null) { @@ -70,8 +72,8 @@ public class RemoteAddrRoutePredicateFactory implements RoutePredicateFactory { log.warn("Remote addresses didn't match " + hostAddress + " != " + host); } - for (SubnetUtils source : sources) { - if (source.getInfo().isInRange(hostAddress)) { + for (IpSubnetFilterRule source : sources) { + if (source.matches(remoteAddress)) { return true; } } @@ -81,18 +83,15 @@ public class RemoteAddrRoutePredicateFactory implements RoutePredicateFactory { }; } - private void addSource(List sources, String source) { - boolean inclusiveHostCount = false; + private void addSource(List sources, String source) { if (!source.contains("/")) { // no netmask, add default source = source + "/32"; } - if (source.endsWith("/32")) { - //http://stackoverflow.com/questions/2942299/converting-cidr-address-to-subnet-mask-and-network-address#answer-6858429 - inclusiveHostCount = true; - } - //TODO: howto support ipv6 as well? - SubnetUtils subnetUtils = new SubnetUtils(source); - subnetUtils.setInclusiveHostCount(inclusiveHostCount); - sources.add(subnetUtils); + + String[] ipAddressCidrPrefix = source.split("/",2); + String ipAddress = ipAddressCidrPrefix[0]; + int cidrPrefix = Integer.parseInt(ipAddressCidrPrefix[1]); + + sources.add(new IpSubnetFilterRule(ipAddress, cidrPrefix, IpFilterRuleType.ACCEPT)); } } diff --git a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/support/SubnetUtils.java b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/support/SubnetUtils.java deleted file mode 100644 index fc7193c39..000000000 --- a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/support/SubnetUtils.java +++ /dev/null @@ -1,364 +0,0 @@ -/* - * Copyright 2013-2017 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package org.springframework.cloud.gateway.support; - -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -/** - * A class that performs some subnet calculations given a network address and a subnet mask. - * See original from commons-net org.apache.commons.net.util.SubnetUtils - * @see "http://www.faqs.org/rfcs/rfc1519.html" - */ -@SuppressWarnings("unused") -public class SubnetUtils { - - private static final String IP_ADDRESS = "(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})"; - private static final String SLASH_FORMAT = IP_ADDRESS + "/(\\d{1,3})"; - private static final Pattern addressPattern = Pattern.compile(IP_ADDRESS); - private static final Pattern cidrPattern = Pattern.compile(SLASH_FORMAT); - private static final int NBITS = 32; - - private int netmask = 0; - private int address = 0; - private int network = 0; - private int broadcast = 0; - - /** Whether the broadcast/network address are included in host count */ - private boolean inclusiveHostCount = false; - - - /** - * Constructor that takes a CIDR-notation string, e.g. "192.168.0.1/16" - * @param cidrNotation A CIDR-notation string, e.g. "192.168.0.1/16" - * @throws IllegalArgumentException if the parameter is invalid, - * i.e. does not match n.n.n.n/m where n=1-3 decimal digits, m = 1-3 decimal digits in range 1-32 - */ - public SubnetUtils(String cidrNotation) { - calculate(cidrNotation); - } - - /** - * Constructor that takes a dotted decimal address and a dotted decimal mask. - * @param address An IP address, e.g. "192.168.0.1" - * @param mask A dotted decimal netmask e.g. "255.255.0.0" - * @throws IllegalArgumentException if the address or mask is invalid, - * i.e. does not match n.n.n.n where n=1-3 decimal digits and the mask is not all zeros - */ - public SubnetUtils(String address, String mask) { - calculate(toCidrNotation(address, mask)); - } - - - /** - * Returns true if the return value of {@link SubnetInfo#getAddressCount()} - * includes the network and broadcast addresses. - * @since 2.2 - * @return true if the hostcount includes the network and broadcast addresses - */ - public boolean isInclusiveHostCount() { - return inclusiveHostCount; - } - - /** - * Set to true if you want the return value of {@link SubnetInfo#getAddressCount()} - * to include the network and broadcast addresses. - * @param inclusiveHostCount true if network and broadcast addresses are to be included - * @since 2.2 - */ - public void setInclusiveHostCount(boolean inclusiveHostCount) { - this.inclusiveHostCount = inclusiveHostCount; - } - - - - /** - * Convenience container for subnet summary information. - * - */ - public final class SubnetInfo { - /* Mask to convert unsigned int to a long (i.e. keep 32 bits) */ - private static final long UNSIGNED_INT_MASK = 0x0FFFFFFFFL; - - private SubnetInfo() {} - - private int netmask() { return netmask; } - private int network() { return network; } - private int address() { return address; } - private int broadcast() { return broadcast; } - - // long versions of the values (as unsigned int) which are more suitable for range checking - private long networkLong() { return network & UNSIGNED_INT_MASK; } - private long broadcastLong(){ return broadcast & UNSIGNED_INT_MASK; } - - private int low() { - return (isInclusiveHostCount() ? network() : - broadcastLong() - networkLong() > 1 ? network() + 1 : 0); - } - - private int high() { - return (isInclusiveHostCount() ? broadcast() : - broadcastLong() - networkLong() > 1 ? broadcast() -1 : 0); - } - - /** - * Returns true if the parameter address is in the - * range of usable endpoint addresses for this subnet. This excludes the - * network and broadcast adresses. - * @param address A dot-delimited IPv4 address, e.g. "192.168.0.1" - * @return True if in range, false otherwise - */ - public boolean isInRange(String address) { - return isInRange(toInteger(address)); - } - - /** - * - * @param address the address to check - * @return true if it is in range - * @since 3.4 (made public) - */ - public boolean isInRange(int address) { - long addLong = address & UNSIGNED_INT_MASK; - long lowLong = low() & UNSIGNED_INT_MASK; - long highLong = high() & UNSIGNED_INT_MASK; - return addLong >= lowLong && addLong <= highLong; - } - - public String getBroadcastAddress() { - return format(toArray(broadcast())); - } - - public String getNetworkAddress() { - return format(toArray(network())); - } - - public String getNetmask() { - return format(toArray(netmask())); - } - - public String getAddress() { - return format(toArray(address())); - } - - /** - * Return the low address as a dotted IP address. - * Will be zero for CIDR/31 and CIDR/32 if the inclusive flag is false. - * - * @return the IP address in dotted format, may be "0.0.0.0" if there is no valid address - */ - public String getLowAddress() { - return format(toArray(low())); - } - - /** - * Return the high address as a dotted IP address. - * Will be zero for CIDR/31 and CIDR/32 if the inclusive flag is false. - * - * @return the IP address in dotted format, may be "0.0.0.0" if there is no valid address - */ - public String getHighAddress() { - return format(toArray(high())); - } - - /** - * Get the count of available addresses. - * Will be zero for CIDR/31 and CIDR/32 if the inclusive flag is false. - * @return the count of addresses, may be zero. - * @throws RuntimeException if the correct count is greater than {@code Integer.MAX_VALUE} - * @deprecated (3.4) use {@link #getAddressCountLong()} instead - */ - @Deprecated - public int getAddressCount() { - long countLong = getAddressCountLong(); - if (countLong > Integer.MAX_VALUE) { - throw new RuntimeException("Count is larger than an integer: " + countLong); - } - // N.B. cannot be negative - return (int)countLong; - } - - /** - * Get the count of available addresses. - * Will be zero for CIDR/31 and CIDR/32 if the inclusive flag is false. - * @return the count of addresses, may be zero. - * @since 3.4 - */ - public long getAddressCountLong() { - long b = broadcastLong(); - long n = networkLong(); - long count = b - n + (isInclusiveHostCount() ? 1 : -1); - return count < 0 ? 0 : count; - } - - public int asInteger(String address) { - return toInteger(address); - } - - public String getCidrSignature() { - return toCidrNotation( - format(toArray(address())), - format(toArray(netmask())) - ); - } - - public String[] getAllAddresses() { - int ct = getAddressCount(); - String[] addresses = new String[ct]; - if (ct == 0) { - return addresses; - } - for (int add = low(), j=0; add <= high(); ++add, ++j) { - addresses[j] = format(toArray(add)); - } - return addresses; - } - - /** - * {@inheritDoc} - * @since 2.2 - */ - @Override - public String toString() { - final StringBuilder buf = new StringBuilder(); - buf.append("CIDR Signature:\t[").append(getCidrSignature()).append("]") - .append(" Netmask: [").append(getNetmask()).append("]\n") - .append("Network:\t[").append(getNetworkAddress()).append("]\n") - .append("Broadcast:\t[").append(getBroadcastAddress()).append("]\n") - .append("First Address:\t[").append(getLowAddress()).append("]\n") - .append("Last Address:\t[").append(getHighAddress()).append("]\n") - .append("# Addresses:\t[").append(getAddressCount()).append("]\n"); - return buf.toString(); - } - } - - /** - * Return a {@link SubnetInfo} instance that contains subnet-specific statistics - * @return new instance - */ - public final SubnetInfo getInfo() { return new SubnetInfo(); } - - /* - * Initialize the internal fields from the supplied CIDR mask - */ - private void calculate(String mask) { - Matcher matcher = cidrPattern.matcher(mask); - - if (matcher.matches()) { - address = matchAddress(matcher); - - /* Create a binary netmask from the number of bits specification /x */ - int cidrPart = rangeCheck(Integer.parseInt(matcher.group(5)), 0, NBITS); - for (int j = 0; j < cidrPart; ++j) { - netmask |= (1 << 31 - j); - } - - /* Calculate base network address */ - network = (address & netmask); - - /* Calculate broadcast address */ - broadcast = network | ~(netmask); - } else { - throw new IllegalArgumentException("Could not parse [" + mask + "]"); - } - } - - /* - * Convert a dotted decimal format address to a packed integer format - */ - private int toInteger(String address) { - Matcher matcher = addressPattern.matcher(address); - if (matcher.matches()) { - return matchAddress(matcher); - } else { - throw new IllegalArgumentException("Could not parse [" + address + "]"); - } - } - - /* - * Convenience method to extract the components of a dotted decimal address and - * pack into an integer using a regex match - */ - private int matchAddress(Matcher matcher) { - int addr = 0; - for (int i = 1; i <= 4; ++i) { - int n = (rangeCheck(Integer.parseInt(matcher.group(i)), 0, 255)); - addr |= ((n & 0xff) << 8*(4-i)); - } - return addr; - } - - /* - * Convert a packed integer address into a 4-element array - */ - private int[] toArray(int val) { - int ret[] = new int[4]; - for (int j = 3; j >= 0; --j) { - ret[j] |= ((val >>> 8*(3-j)) & (0xff)); - } - return ret; - } - - /* - * Convert a 4-element array into dotted decimal format - */ - private String format(int[] octets) { - StringBuilder str = new StringBuilder(); - for (int i =0; i < octets.length; ++i){ - str.append(octets[i]); - if (i != octets.length - 1) { - str.append("."); - } - } - return str.toString(); - } - - /* - * Convenience function to check integer boundaries. - * Checks if a value x is in the range [begin,end]. - * Returns x if it is in range, throws an exception otherwise. - */ - private int rangeCheck(int value, int begin, int end) { - if (value >= begin && value <= end) { // (begin,end] - return value; - } - - throw new IllegalArgumentException("Value [" + value + "] not in range ["+begin+","+end+"]"); - } - - /* - * Count the number of 1-bits in a 32-bit integer using a divide-and-conquer strategy - * see Hacker's Delight section 5.1 - */ - int pop(int x) { - x = x - ((x >>> 1) & 0x55555555); - x = (x & 0x33333333) + ((x >>> 2) & 0x33333333); - x = (x + (x >>> 4)) & 0x0F0F0F0F; - x = x + (x >>> 8); - x = x + (x >>> 16); - return x & 0x0000003F; - } - - /* Convert two dotted decimal addresses to a single xxx.xxx.xxx.xxx/yy format - * by counting the 1-bit population in the mask address. (It may be better to count - * NBITS-#trailing zeroes for this case) - */ - private String toCidrNotation(String addr, String mask) { - return addr + "/" + pop(toInteger(mask)); - } -} diff --git a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/handler/predicate/RemoteAddrRoutePredicateFactoryTests.java b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/handler/predicate/RemoteAddrRoutePredicateFactoryTests.java new file mode 100644 index 000000000..3523fb15d --- /dev/null +++ b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/handler/predicate/RemoteAddrRoutePredicateFactoryTests.java @@ -0,0 +1,53 @@ +package org.springframework.cloud.gateway.handler.predicate; + +import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT; +import static org.springframework.cloud.gateway.test.TestUtils.assertStatus; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.cloud.gateway.test.BaseWebClientTests; +import org.springframework.context.annotation.Import; +import org.springframework.http.HttpStatus; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.junit4.SpringRunner; +import org.springframework.web.reactive.function.client.ClientResponse; + +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +@RunWith(SpringRunner.class) +@SpringBootTest(webEnvironment = RANDOM_PORT) +@DirtiesContext +@ActiveProfiles({ "remote-address" }) +public class RemoteAddrRoutePredicateFactoryTests extends BaseWebClientTests { + + @Test + public void pathRouteWorks() { + Mono result = webClient.get().uri("/ok/httpbin/").exchange(); + + StepVerifier.create(result) + .consumeNextWith(response -> assertStatus(response, HttpStatus.OK)) + .expectComplete().verify(DURATION); + } + + @Test + public void pathRouteDoNotWork() { + Mono result = webClient.get().uri("/nok/httpbin/").exchange(); + + StepVerifier + .create(result) + .consumeNextWith(response -> assertStatus(response, HttpStatus.NOT_FOUND)) + .expectComplete().verify(DURATION); + } + + @EnableAutoConfiguration + @SpringBootConfiguration + @Import(DefaultTestConfig.class) + public static class TestConfig { + } + +} diff --git a/spring-cloud-gateway-core/src/test/resources/application-remote-address.yml b/spring-cloud-gateway-core/src/test/resources/application-remote-address.yml new file mode 100644 index 000000000..0fdef541f --- /dev/null +++ b/spring-cloud-gateway-core/src/test/resources/application-remote-address.yml @@ -0,0 +1,27 @@ +test: + uri: lb://myservice + +spring: + cloud: + gateway: + default-filters: + routes: + # ===================================== + - id: remote_address_ipv6_test + uri: ${test.uri} + predicates: + - Path=/ok/httpbin/ + - RemoteAddr=2001:db8:abcd:0012::0/64,::1/32,127.0.0.1 + filters: + - SetPath=/httpbin/ + - SetStatus=200 + + # ===================================== + - id: remote_address_ipv6_test_other_ip + uri: ${test.uri} + predicates: + - Path=/nok/httpbin/ + - RemoteAddr=2001:db8:abcd:0012::0/64 + filters: + - SetPath=/httpbin/ + - SetStatus=200