diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/DefaultServerHttpRequestBuilder.java b/spring-web/src/main/java/org/springframework/http/server/reactive/DefaultServerHttpRequestBuilder.java index f357909a27..06c024fe50 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/DefaultServerHttpRequestBuilder.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/DefaultServerHttpRequestBuilder.java @@ -57,6 +57,9 @@ class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder { @Nullable private String contextPath; + @Nullable + private SslInfo sslInfo; + private Flux body; private final ServerHttpRequest originalRequest; @@ -97,6 +100,7 @@ class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder { @Override public ServerHttpRequest.Builder path(String path) { + Assert.isTrue(path.startsWith("/"), "The path does not have a leading slash."); this.uriPath = path; return this; } @@ -120,10 +124,16 @@ class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder { return this; } + @Override + public ServerHttpRequest.Builder sslInfo(SslInfo sslInfo) { + this.sslInfo = sslInfo; + return this; + } + @Override public ServerHttpRequest build() { - return new DefaultServerHttpRequest(getUriToUse(), this.contextPath, this.httpHeaders, - this.httpMethodValue, this.cookies, this.body, this.originalRequest); + return new MutatedServerHttpRequest(getUriToUse(), this.contextPath, this.httpHeaders, + this.httpMethodValue, this.cookies, this.sslInfo, this.body, this.originalRequest); } private URI getUriToUse() { @@ -165,7 +175,7 @@ class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder { } - private static class DefaultServerHttpRequest extends AbstractServerHttpRequest { + private static class MutatedServerHttpRequest extends AbstractServerHttpRequest { private final String methodValue; @@ -181,15 +191,16 @@ class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder { private final ServerHttpRequest originalRequest; - public DefaultServerHttpRequest(URI uri, @Nullable String contextPath, + + public MutatedServerHttpRequest(URI uri, @Nullable String contextPath, HttpHeaders headers, String methodValue, MultiValueMap cookies, - Flux body, ServerHttpRequest originalRequest) { + @Nullable SslInfo sslInfo, Flux body, ServerHttpRequest originalRequest) { super(uri, contextPath, headers); this.methodValue = methodValue; this.cookies = cookies; this.remoteAddress = originalRequest.getRemoteAddress(); - this.sslInfo = originalRequest.getSslInfo(); + this.sslInfo = sslInfo != null ? sslInfo : originalRequest.getSslInfo(); this.body = body; this.originalRequest = originalRequest; } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequest.java index 426ae7eb7e..73a9d78b1e 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequest.java @@ -95,18 +95,36 @@ public interface ServerHttpRequest extends HttpRequest, ReactiveHttpInputMessage Builder method(HttpMethod httpMethod); /** - * Set the URI to return. + * Set the URI to use with the following conditions: + *
    + *
  • If {@link #path(String) path} is also set, it overrides the path + * of the URI provided here. + *
  • If {@link #contextPath(String) contextPath} is also set, or + * already present, it must match the start of the path of the URI + * provided here. + *
*/ Builder uri(URI uri); /** - * Set the path to use instead of the {@code "rawPath"} of - * {@link ServerHttpRequest#getURI()}. + * Set the path to use instead of the {@code "rawPath"} of the URI of + * the request with the following conditions: + *
    + *
  • If {@link #uri(URI) uri} is also set, the path given here + * overrides the path of the given URI. + *
  • If {@link #contextPath(String) contextPath} is also set, or + * already present, it must match the start of the path given here. + *
  • The given value must begin with a slash. + *
*/ Builder path(String path); /** * Set the contextPath to use. + *

The given value must be a valid {@link RequestPath#contextPath() + * contextPath} and it must match the start of the path of the URI of + * the request. That means changing the contextPath, implies also + * changing the path via {@link #path(String)}. */ Builder contextPath(String contextPath); @@ -116,16 +134,22 @@ public interface ServerHttpRequest extends HttpRequest, ReactiveHttpInputMessage Builder header(String key, String value); /** - * Manipulate this request's headers with the given consumer. The - * headers provided to the consumer are "live", so that the consumer can be used to - * {@linkplain HttpHeaders#set(String, String) overwrite} existing header values, - * {@linkplain HttpHeaders#remove(Object) remove} values, or use any of the other - * {@link HttpHeaders} methods. - * @param headersConsumer a function that consumes the {@code HttpHeaders} - * @return this builder + * Manipulate request headers. The provided {@code HttpHeaders} contains + * current request headers, so that the {@code Consumer} can + * {@linkplain HttpHeaders#set(String, String) overwrite} or + * {@linkplain HttpHeaders#remove(Object) remove} existing values, or + * use any other {@link HttpHeaders} methods. */ Builder headers(Consumer headersConsumer); + /** + * Set the SSL session information. This may be useful in environments + * where TLS termination is done at the router, but SSL information is + * made available in some other way such as through a header. + * @since 5.0.7 + */ + Builder sslInfo(SslInfo sslInfo); + /** * Build a {@link ServerHttpRequest} decorator with the mutated properties. */ diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java index a40e9bbd51..aba35e99c1 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java @@ -17,16 +17,17 @@ package org.springframework.http.server.reactive; import java.io.ByteArrayInputStream; +import java.net.URI; import java.util.Arrays; import java.util.Collections; import javax.servlet.AsyncContext; import javax.servlet.ReadListener; import javax.servlet.ServletInputStream; -import javax.servlet.http.HttpServletRequest; import org.junit.Test; import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpMethod; import org.springframework.mock.web.test.DelegatingServletInputStream; import org.springframework.mock.web.test.MockAsyncContext; import org.springframework.mock.web.test.MockHttpServletRequest; @@ -34,6 +35,7 @@ import org.springframework.mock.web.test.MockHttpServletResponse; import org.springframework.util.MultiValueMap; import static org.junit.Assert.*; +import static org.mockito.Mockito.*; /** * Unit tests for {@link AbstractServerHttpRequest}. @@ -84,33 +86,78 @@ public class ServerHttpRequestTests { assertEquals(Collections.singletonList(null), params.get("a")); } + @Test + public void mutateRequest() throws Exception { + + SslInfo sslInfo = mock(SslInfo.class); + ServerHttpRequest request = createHttpRequest("/").mutate().sslInfo(sslInfo).build(); + assertSame(sslInfo, request.getSslInfo()); + + request = createHttpRequest("/").mutate().method(HttpMethod.DELETE).build(); + assertEquals(HttpMethod.DELETE, request.getMethod()); + + String baseUri = "http://aaa.org:8080/a"; + + request = createHttpRequest(baseUri).mutate().uri(URI.create("http://bbb.org:9090/b")).build(); + assertEquals("http://bbb.org:9090/b", request.getURI().toString()); + + request = createHttpRequest(baseUri).mutate().path("/b/c/d").build(); + assertEquals("http://aaa.org:8080/b/c/d", request.getURI().toString()); + + request = createHttpRequest(baseUri).mutate().path("/app/b/c/d").contextPath("/app").build(); + assertEquals("http://aaa.org:8080/app/b/c/d", request.getURI().toString()); + assertEquals("/app", request.getPath().contextPath().value()); + } + + @Test(expected = IllegalArgumentException.class) + public void mutateWithInvalidPath() throws Exception { + createHttpRequest("/").mutate().path("foo-bar"); + } + @Test // SPR-16434 public void mutatePathWithEncodedQueryParams() throws Exception { - ServerHttpRequest request = createHttpRequest("/path?name=%E6%89%8E%E6%A0%B9") - .mutate().path("/mutatedPath").build(); + ServerHttpRequest request = createHttpRequest("/path?name=%E6%89%8E%E6%A0%B9"); + request = request.mutate().path("/mutatedPath").build(); + assertEquals("/mutatedPath", request.getURI().getRawPath()); assertEquals("name=%E6%89%8E%E6%A0%B9", request.getURI().getRawQuery()); } - - private ServerHttpRequest createHttpRequest(String path) throws Exception { - HttpServletRequest request = createEmptyBodyHttpServletRequest(path); + private ServerHttpRequest createHttpRequest(String uriString) throws Exception { + URI uri = URI.create(uriString); + MockHttpServletRequest request = new TestHttpServletRequest(uri); AsyncContext asyncContext = new MockAsyncContext(request, new MockHttpServletResponse()); return new ServletServerHttpRequest(request, asyncContext, "", new DefaultDataBufferFactory(), 1024); } - private HttpServletRequest createEmptyBodyHttpServletRequest(String path) { - return new MockHttpServletRequest("GET", path) { + + private static class TestHttpServletRequest extends MockHttpServletRequest { + + TestHttpServletRequest(URI uri) { + super("GET", uri.getRawPath()); + if (uri.getScheme() != null) { + setScheme(uri.getScheme()); + } + if (uri.getHost() != null) { + setServerName(uri.getHost()); + } + if (uri.getPort() != -1) { + setServerPort(uri.getPort()); + } + if (uri.getRawQuery() != null) { + setQueryString(uri.getRawQuery()); + } + } + + @Override + public ServletInputStream getInputStream() { + return new DelegatingServletInputStream(new ByteArrayInputStream(new byte[0])) { @Override - public ServletInputStream getInputStream() { - return new DelegatingServletInputStream(new ByteArrayInputStream(new byte[0])) { - @Override - public void setReadListener(ReadListener readListener) { - // Ignore - } - }; + public void setReadListener(ReadListener readListener) { + // Ignore } }; + } } }