From 6b410df45b6dc4824162996a00c5360ab1563926 Mon Sep 17 00:00:00 2001 From: Sam Brannen Date: Fri, 17 Apr 2020 15:01:35 +0200 Subject: [PATCH] Retain brackets for IPV6 address in MockHttpServletRequest According to the Javadoc for ServletRequest's getServerName() method, when the `Host` header is set, the server name is "the value of the part before ':' in the Host header value ...". For a value representing an IPV6 address such as `[::ffff:abcd:abcd]`, the enclosing square brackets should therefore not be stripped from the enclosed IPV6 address. However, the changes made in conjunction with gh-16704 introduced a regression in Spring Framework 4.1 for the getServerName() method in MockHttpServletRequest by stripping the enclosing brackets from the IPV6 address in the `Host` header. Similarly, the changes made in conjunction with gh-20686 introduced a regression in Spring Framework 4.3.13 and 5.0.2 in the getRequestURL() method in MockHttpServletRequest by delegating to the getServerName() method which strips the enclosing brackets. This commit fixes the implementation of getServerName() so that the enclosing brackets are no longer stripped from an IPV6 address in the `Host` header. The implementation of getRequestURL() is therefore also fixed. In addition, in order to avoid a NullPointerException, the implementations of getServerName() and getServerPort() now assert that an IPV6 address present in the `Host` header correctly contains an opening and closing bracket and throw an IllegalStateException if that is not the case. Closes gh-24916 --- .../mock/web/MockHttpServletRequest.java | 14 +++- .../mock/web/MockHttpServletRequestTests.java | 76 +++++++++++++++++-- 2 files changed, 79 insertions(+), 11 deletions(-) diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java index ac8f194376..8e9961208c 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockHttpServletRequest.java @@ -668,11 +668,14 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public String getServerName() { - String host = getHeader(HttpHeaders.HOST); + String rawHostHeader = getHeader(HttpHeaders.HOST); + String host = rawHostHeader; if (host != null) { host = host.trim(); if (host.startsWith("[")) { - host = host.substring(1, host.indexOf(']')); + int indexOfClosingBracket = host.indexOf(']'); + Assert.state(indexOfClosingBracket > -1, () -> "Invalid Host header: " + rawHostHeader); + host = host.substring(0, indexOfClosingBracket + 1); } else if (host.contains(":")) { host = host.substring(0, host.indexOf(':')); @@ -690,12 +693,15 @@ public class MockHttpServletRequest implements HttpServletRequest { @Override public int getServerPort() { - String host = getHeader(HttpHeaders.HOST); + String rawHostHeader = getHeader(HttpHeaders.HOST); + String host = rawHostHeader; if (host != null) { host = host.trim(); int idx; if (host.startsWith("[")) { - idx = host.indexOf(':', host.indexOf(']')); + int indexOfClosingBracket = host.indexOf(']'); + Assert.state(indexOfClosingBracket > -1, () -> "Invalid Host header: " + rawHostHeader); + idx = host.indexOf(':', indexOfClosingBracket); } else { idx = host.indexOf(':'); diff --git a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java index 9f58f85922..76104fbb6f 100644 --- a/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java +++ b/spring-test/src/test/java/org/springframework/mock/web/MockHttpServletRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.mock.web; import java.io.IOException; +import java.net.URL; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; @@ -37,6 +38,7 @@ import org.springframework.util.FileCopyUtils; import org.springframework.util.StreamUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; @@ -389,18 +391,25 @@ class MockHttpServletRequestTests { assertThat(request.getServerName()).isEqualTo(testServer); } + @Test + void getServerNameWithInvalidIpv6AddressViaHostHeader() { + request.addHeader(HOST, "[::ffff:abcd:abcd"); // missing closing bracket + assertThatIllegalStateException() + .isThrownBy(() -> request.getServerName()) + .withMessageStartingWith("Invalid Host header: "); + } + @Test void getServerNameViaHostHeaderAsIpv6AddressWithoutPort() { - String ipv6Address = "[2001:db8:0:1]"; - request.addHeader(HOST, ipv6Address); - assertThat(request.getServerName()).isEqualTo("2001:db8:0:1"); + String host = "[2001:db8:0:1]"; + request.addHeader(HOST, host); + assertThat(request.getServerName()).isEqualTo(host); } @Test void getServerNameViaHostHeaderAsIpv6AddressWithPort() { - String ipv6Address = "[2001:db8:0:1]:8081"; - request.addHeader(HOST, ipv6Address); - assertThat(request.getServerName()).isEqualTo("2001:db8:0:1"); + request.addHeader(HOST, "[2001:db8:0:1]:8081"); + assertThat(request.getServerName()).isEqualTo("[2001:db8:0:1]"); } @Test @@ -414,6 +423,22 @@ class MockHttpServletRequestTests { assertThat(request.getServerPort()).isEqualTo(8080); } + @Test + void getServerPortWithInvalidIpv6AddressViaHostHeader() { + request.addHeader(HOST, "[::ffff:abcd:abcd:8080"); // missing closing bracket + assertThatIllegalStateException() + .isThrownBy(() -> request.getServerPort()) + .withMessageStartingWith("Invalid Host header: "); + } + + @Test + void getServerPortWithIpv6AddressAndInvalidPortViaHostHeader() { + request.addHeader(HOST, "[::ffff:abcd:abcd]:bogus"); // "bogus" is not a port number + assertThatExceptionOfType(NumberFormatException.class) + .isThrownBy(() -> request.getServerPort()) + .withMessageContaining("bogus"); + } + @Test void getServerPortViaHostHeaderAsIpv6AddressWithoutPort() { String testServer = "[2001:db8:0:1]"; @@ -478,6 +503,43 @@ class MockHttpServletRequestTests { assertThat(requestURL.toString()).isEqualTo(("http://" + testServer)); } + @Test + void getRequestURLWithIpv6AddressViaServerNameWithoutPort() throws Exception { + request.setServerName("[::ffff:abcd:abcd]"); + URL url = new java.net.URL(request.getRequestURL().toString()); + assertThat(url).asString().isEqualTo("http://[::ffff:abcd:abcd]"); + } + + @Test + void getRequestURLWithIpv6AddressViaServerNameWithPort() throws Exception { + request.setServerName("[::ffff:abcd:abcd]"); + request.setServerPort(9999); + URL url = new java.net.URL(request.getRequestURL().toString()); + assertThat(url).asString().isEqualTo("http://[::ffff:abcd:abcd]:9999"); + } + + @Test + void getRequestURLWithInvalidIpv6AddressViaHostHeader() { + request.addHeader(HOST, "[::ffff:abcd:abcd"); // missing closing bracket + assertThatIllegalStateException() + .isThrownBy(() -> request.getRequestURL()) + .withMessageStartingWith("Invalid Host header: "); + } + + @Test + void getRequestURLWithIpv6AddressViaHostHeaderWithoutPort() throws Exception { + request.addHeader(HOST, "[::ffff:abcd:abcd]"); + URL url = new java.net.URL(request.getRequestURL().toString()); + assertThat(url).asString().isEqualTo("http://[::ffff:abcd:abcd]"); + } + + @Test + void getRequestURLWithIpv6AddressViaHostHeaderWithPort() throws Exception { + request.addHeader(HOST, "[::ffff:abcd:abcd]:9999"); + URL url = new java.net.URL(request.getRequestURL().toString()); + assertThat(url).asString().isEqualTo("http://[::ffff:abcd:abcd]:9999"); + } + @Test void getRequestURLWithNullRequestUri() { request.setRequestURI(null);