Browse Source

Support for SslInfo in ServerHttpRequest#mutate

Issue: SPR-16830
pull/1838/head
Rossen Stoyanchev 7 years ago
parent
commit
e3e975d7f9
  1. 23
      spring-web/src/main/java/org/springframework/http/server/reactive/DefaultServerHttpRequestBuilder.java
  2. 44
      spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequest.java
  3. 77
      spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java

23
spring-web/src/main/java/org/springframework/http/server/reactive/DefaultServerHttpRequestBuilder.java

@ -57,6 +57,9 @@ class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder {
@Nullable @Nullable
private String contextPath; private String contextPath;
@Nullable
private SslInfo sslInfo;
private Flux<DataBuffer> body; private Flux<DataBuffer> body;
private final ServerHttpRequest originalRequest; private final ServerHttpRequest originalRequest;
@ -97,6 +100,7 @@ class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder {
@Override @Override
public ServerHttpRequest.Builder path(String path) { public ServerHttpRequest.Builder path(String path) {
Assert.isTrue(path.startsWith("/"), "The path does not have a leading slash.");
this.uriPath = path; this.uriPath = path;
return this; return this;
} }
@ -120,10 +124,16 @@ class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder {
return this; return this;
} }
@Override
public ServerHttpRequest.Builder sslInfo(SslInfo sslInfo) {
this.sslInfo = sslInfo;
return this;
}
@Override @Override
public ServerHttpRequest build() { public ServerHttpRequest build() {
return new DefaultServerHttpRequest(getUriToUse(), this.contextPath, this.httpHeaders, return new MutatedServerHttpRequest(getUriToUse(), this.contextPath, this.httpHeaders,
this.httpMethodValue, this.cookies, this.body, this.originalRequest); this.httpMethodValue, this.cookies, this.sslInfo, this.body, this.originalRequest);
} }
private URI getUriToUse() { private URI getUriToUse() {
@ -165,7 +175,7 @@ class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder {
} }
private static class DefaultServerHttpRequest extends AbstractServerHttpRequest { private static class MutatedServerHttpRequest extends AbstractServerHttpRequest {
private final String methodValue; private final String methodValue;
@ -181,15 +191,16 @@ class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder {
private final ServerHttpRequest originalRequest; private final ServerHttpRequest originalRequest;
public DefaultServerHttpRequest(URI uri, @Nullable String contextPath,
public MutatedServerHttpRequest(URI uri, @Nullable String contextPath,
HttpHeaders headers, String methodValue, MultiValueMap<String, HttpCookie> cookies, HttpHeaders headers, String methodValue, MultiValueMap<String, HttpCookie> cookies,
Flux<DataBuffer> body, ServerHttpRequest originalRequest) { @Nullable SslInfo sslInfo, Flux<DataBuffer> body, ServerHttpRequest originalRequest) {
super(uri, contextPath, headers); super(uri, contextPath, headers);
this.methodValue = methodValue; this.methodValue = methodValue;
this.cookies = cookies; this.cookies = cookies;
this.remoteAddress = originalRequest.getRemoteAddress(); this.remoteAddress = originalRequest.getRemoteAddress();
this.sslInfo = originalRequest.getSslInfo(); this.sslInfo = sslInfo != null ? sslInfo : originalRequest.getSslInfo();
this.body = body; this.body = body;
this.originalRequest = originalRequest; this.originalRequest = originalRequest;
} }

44
spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequest.java

@ -95,18 +95,36 @@ public interface ServerHttpRequest extends HttpRequest, ReactiveHttpInputMessage
Builder method(HttpMethod httpMethod); Builder method(HttpMethod httpMethod);
/** /**
* Set the URI to return. * Set the URI to use with the following conditions:
* <ul>
* <li>If {@link #path(String) path} is also set, it overrides the path
* of the URI provided here.
* <li>If {@link #contextPath(String) contextPath} is also set, or
* already present, it must match the start of the path of the URI
* provided here.
* </ul>
*/ */
Builder uri(URI uri); Builder uri(URI uri);
/** /**
* Set the path to use instead of the {@code "rawPath"} of * Set the path to use instead of the {@code "rawPath"} of the URI of
* {@link ServerHttpRequest#getURI()}. * the request with the following conditions:
* <ul>
* <li>If {@link #uri(URI) uri} is also set, the path given here
* overrides the path of the given URI.
* <li>If {@link #contextPath(String) contextPath} is also set, or
* already present, it must match the start of the path given here.
* <li>The given value must begin with a slash.
* </ul>
*/ */
Builder path(String path); Builder path(String path);
/** /**
* Set the contextPath to use. * Set the contextPath to use.
* <p>The given value must be a valid {@link RequestPath#contextPath()
* contextPath} and it must match the start of the path of the URI of
* the request. That means changing the contextPath, implies also
* changing the path via {@link #path(String)}.
*/ */
Builder contextPath(String contextPath); Builder contextPath(String contextPath);
@ -116,16 +134,22 @@ public interface ServerHttpRequest extends HttpRequest, ReactiveHttpInputMessage
Builder header(String key, String value); Builder header(String key, String value);
/** /**
* Manipulate this request's headers with the given consumer. The * Manipulate request headers. The provided {@code HttpHeaders} contains
* headers provided to the consumer are "live", so that the consumer can be used to * current request headers, so that the {@code Consumer} can
* {@linkplain HttpHeaders#set(String, String) overwrite} existing header values, * {@linkplain HttpHeaders#set(String, String) overwrite} or
* {@linkplain HttpHeaders#remove(Object) remove} values, or use any of the other * {@linkplain HttpHeaders#remove(Object) remove} existing values, or
* {@link HttpHeaders} methods. * use any other {@link HttpHeaders} methods.
* @param headersConsumer a function that consumes the {@code HttpHeaders}
* @return this builder
*/ */
Builder headers(Consumer<HttpHeaders> headersConsumer); Builder headers(Consumer<HttpHeaders> headersConsumer);
/**
* Set the SSL session information. This may be useful in environments
* where TLS termination is done at the router, but SSL information is
* made available in some other way such as through a header.
* @since 5.0.7
*/
Builder sslInfo(SslInfo sslInfo);
/** /**
* Build a {@link ServerHttpRequest} decorator with the mutated properties. * Build a {@link ServerHttpRequest} decorator with the mutated properties.
*/ */

77
spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java

@ -17,16 +17,17 @@
package org.springframework.http.server.reactive; package org.springframework.http.server.reactive;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.net.URI;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import javax.servlet.AsyncContext; import javax.servlet.AsyncContext;
import javax.servlet.ReadListener; import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream; import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import org.junit.Test; import org.junit.Test;
import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpMethod;
import org.springframework.mock.web.test.DelegatingServletInputStream; import org.springframework.mock.web.test.DelegatingServletInputStream;
import org.springframework.mock.web.test.MockAsyncContext; import org.springframework.mock.web.test.MockAsyncContext;
import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletRequest;
@ -34,6 +35,7 @@ import org.springframework.mock.web.test.MockHttpServletResponse;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import static org.junit.Assert.*; import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
/** /**
* Unit tests for {@link AbstractServerHttpRequest}. * Unit tests for {@link AbstractServerHttpRequest}.
@ -84,33 +86,78 @@ public class ServerHttpRequestTests {
assertEquals(Collections.singletonList(null), params.get("a")); assertEquals(Collections.singletonList(null), params.get("a"));
} }
@Test
public void mutateRequest() throws Exception {
SslInfo sslInfo = mock(SslInfo.class);
ServerHttpRequest request = createHttpRequest("/").mutate().sslInfo(sslInfo).build();
assertSame(sslInfo, request.getSslInfo());
request = createHttpRequest("/").mutate().method(HttpMethod.DELETE).build();
assertEquals(HttpMethod.DELETE, request.getMethod());
String baseUri = "http://aaa.org:8080/a";
request = createHttpRequest(baseUri).mutate().uri(URI.create("http://bbb.org:9090/b")).build();
assertEquals("http://bbb.org:9090/b", request.getURI().toString());
request = createHttpRequest(baseUri).mutate().path("/b/c/d").build();
assertEquals("http://aaa.org:8080/b/c/d", request.getURI().toString());
request = createHttpRequest(baseUri).mutate().path("/app/b/c/d").contextPath("/app").build();
assertEquals("http://aaa.org:8080/app/b/c/d", request.getURI().toString());
assertEquals("/app", request.getPath().contextPath().value());
}
@Test(expected = IllegalArgumentException.class)
public void mutateWithInvalidPath() throws Exception {
createHttpRequest("/").mutate().path("foo-bar");
}
@Test // SPR-16434 @Test // SPR-16434
public void mutatePathWithEncodedQueryParams() throws Exception { public void mutatePathWithEncodedQueryParams() throws Exception {
ServerHttpRequest request = createHttpRequest("/path?name=%E6%89%8E%E6%A0%B9") ServerHttpRequest request = createHttpRequest("/path?name=%E6%89%8E%E6%A0%B9");
.mutate().path("/mutatedPath").build(); request = request.mutate().path("/mutatedPath").build();
assertEquals("/mutatedPath", request.getURI().getRawPath()); assertEquals("/mutatedPath", request.getURI().getRawPath());
assertEquals("name=%E6%89%8E%E6%A0%B9", request.getURI().getRawQuery()); assertEquals("name=%E6%89%8E%E6%A0%B9", request.getURI().getRawQuery());
} }
private ServerHttpRequest createHttpRequest(String uriString) throws Exception {
private ServerHttpRequest createHttpRequest(String path) throws Exception { URI uri = URI.create(uriString);
HttpServletRequest request = createEmptyBodyHttpServletRequest(path); MockHttpServletRequest request = new TestHttpServletRequest(uri);
AsyncContext asyncContext = new MockAsyncContext(request, new MockHttpServletResponse()); AsyncContext asyncContext = new MockAsyncContext(request, new MockHttpServletResponse());
return new ServletServerHttpRequest(request, asyncContext, "", new DefaultDataBufferFactory(), 1024); return new ServletServerHttpRequest(request, asyncContext, "", new DefaultDataBufferFactory(), 1024);
} }
private HttpServletRequest createEmptyBodyHttpServletRequest(String path) {
return new MockHttpServletRequest("GET", path) { private static class TestHttpServletRequest extends MockHttpServletRequest {
TestHttpServletRequest(URI uri) {
super("GET", uri.getRawPath());
if (uri.getScheme() != null) {
setScheme(uri.getScheme());
}
if (uri.getHost() != null) {
setServerName(uri.getHost());
}
if (uri.getPort() != -1) {
setServerPort(uri.getPort());
}
if (uri.getRawQuery() != null) {
setQueryString(uri.getRawQuery());
}
}
@Override
public ServletInputStream getInputStream() {
return new DelegatingServletInputStream(new ByteArrayInputStream(new byte[0])) {
@Override @Override
public ServletInputStream getInputStream() { public void setReadListener(ReadListener readListener) {
return new DelegatingServletInputStream(new ByteArrayInputStream(new byte[0])) { // Ignore
@Override
public void setReadListener(ReadListener readListener) {
// Ignore
}
};
} }
}; };
}
} }
} }

Loading…
Cancel
Save