Browse Source

Merge branch 'XMANcharles-master'

pull/275/head
Spencer Gibb 7 years ago
parent
commit
d6f64dd74f
No known key found for this signature in database
GPG Key ID: 7788A47380690861
  1. 33
      spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/WebsocketRoutingFilter.java
  2. 36
      spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/test/websocket/WebSocketIntegrationTests.java

33
spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/WebsocketRoutingFilter.java

@ -7,6 +7,8 @@ import java.util.Collections; @@ -7,6 +7,8 @@ import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Mono;
import org.springframework.beans.factory.ObjectProvider;
@ -19,6 +21,7 @@ import org.springframework.web.reactive.socket.WebSocketSession; @@ -19,6 +21,7 @@ import org.springframework.web.reactive.socket.WebSocketSession;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;
import static org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter.filterRequest;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;
@ -30,6 +33,7 @@ import static org.springframework.util.StringUtils.commaDelimitedListToStringArr @@ -30,6 +33,7 @@ import static org.springframework.util.StringUtils.commaDelimitedListToStringArr
* @author Spencer Gibb
*/
public class WebsocketRoutingFilter implements GlobalFilter, Ordered {
private static final Log log = LogFactory.getLog(WebsocketRoutingFilter.class);
public static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol";
private final WebSocketClient webSocketClient;
@ -46,14 +50,17 @@ public class WebsocketRoutingFilter implements GlobalFilter, Ordered { @@ -46,14 +50,17 @@ public class WebsocketRoutingFilter implements GlobalFilter, Ordered {
@Override
public int getOrder() {
return Ordered.LOWEST_PRECEDENCE;
// Before NettyRoutingFilter since this routes certain http requests
return Ordered.LOWEST_PRECEDENCE - 1;
}
@Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR);
changeSchemeIfIsWebSocketUpgrade(exchange);
URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR);
String scheme = requestUrl.getScheme();
if (isAlreadyRouted(exchange) || (!"ws".equals(scheme) && !"wss".equals(scheme))) {
return chain.filter(exchange);
}
@ -94,6 +101,24 @@ public class WebsocketRoutingFilter implements GlobalFilter, Ordered { @@ -94,6 +101,24 @@ public class WebsocketRoutingFilter implements GlobalFilter, Ordered {
return filters;
}
private void changeSchemeIfIsWebSocketUpgrade(ServerWebExchange exchange) {
// Check the Upgrade
URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR);
String scheme = requestUrl.getScheme();
String upgrade = exchange.getRequest().getHeaders().getUpgrade();
// change the scheme if the socket client send a "http" or "https"
if ("WebSocket".equalsIgnoreCase(upgrade) && ("http".equals(scheme) || "https".equals(scheme))) {
String wsScheme = convertHttpToWs(scheme);
URI wsRequestUrl = UriComponentsBuilder.fromUri(requestUrl).scheme(wsScheme).build().toUri();
exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, wsRequestUrl);
log.trace("changeSchemeTo:[" + wsRequestUrl+"]");
}
}
private String convertHttpToWs(String scheme) {
return "http".equals(scheme) ? "ws" : "https".equals(scheme) ? "wws" : scheme;
}
private static class ProxyWebSocketHandler implements WebSocketHandler {
private final WebSocketClient client;
@ -126,10 +151,10 @@ public class WebsocketRoutingFilter implements GlobalFilter, Ordered { @@ -126,10 +151,10 @@ public class WebsocketRoutingFilter implements GlobalFilter, Ordered {
// Use retain() for Reactor Netty
Mono<Void> proxySessionSend = proxySession
.send(session.receive().doOnNext(WebSocketMessage::retain));
// .log("proxySessionSend", Level.FINE);
// .log("proxySessionSend", Level.FINE);
Mono<Void> serverSessionSend = session
.send(proxySession.receive().doOnNext(WebSocketMessage::retain));
// .log("sessionSend", Level.FINE);
// .log("sessionSend", Level.FINE);
return Mono.when(proxySessionSend, serverSessionSend);
}

36
spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/test/websocket/WebSocketIntegrationTests.java

@ -141,7 +141,11 @@ public class WebSocketIntegrationTests { @@ -141,7 +141,11 @@ public class WebSocketIntegrationTests {
protected URI getUrl(String path) throws URISyntaxException {
// return new URI("ws://localhost:" + this.serverPort + path);
return new URI("ws://localhost:" + this.gatewayPort + path);
return new URI("ws://localhost:" + this.gatewayPort + path);
}
protected URI getHttpUrl(String path) throws URISyntaxException {
return new URI("http://localhost:" + this.gatewayPort + path);
}
@Configuration
@ -170,6 +174,7 @@ public class WebSocketIntegrationTests { @@ -170,6 +174,7 @@ public class WebSocketIntegrationTests {
public HandlerMapping handlerMapping() {
Map<String, WebSocketHandler> map = new HashMap<>();
map.put("/echo", new EchoWebSocketHandler());
map.put("/echoForHttp", new EchoWebSocketHandler());
map.put("/sub-protocol", new SubProtocolWebSocketHandler());
map.put("/custom-header", new CustomHeaderHandler());
@ -203,6 +208,31 @@ public class WebSocketIntegrationTests { @@ -203,6 +208,31 @@ public class WebSocketIntegrationTests {
output.collectList().block(Duration.ofMillis(5000)));
}
@Test
public void echoForHttp() throws Exception {
int count = 100;
Flux<String> input = Flux.range(1, count).map(index -> "msg-" + index);
ReplayProcessor<Object> output = ReplayProcessor.create(count);
client.execute(getHttpUrl("/echoForHttp"),
session -> {
logger.debug("Starting to send messages");
return session
.send(input.doOnNext(s -> logger.debug("outbound " + s)).map(session::textMessage))
.thenMany(session.receive().take(count).map(WebSocketMessage::getPayloadAsText))
.subscribeWith(output)
.doOnNext(s -> logger.debug("inbound " + s))
.then()
.doOnSuccessOrError((aVoid, ex) ->
logger.debug("Done with " + (ex != null ? ex.getMessage() : "success")));
})
.block(Duration.ofMillis(5000));
assertEquals(input.collectList().block(Duration.ofMillis(5000)),
output.collectList().block(Duration.ofMillis(5000)));
}
@Test
public void subProtocol() throws Exception {
String protocol = "echo-v1";
@ -316,8 +346,10 @@ public class WebSocketIntegrationTests { @@ -316,8 +346,10 @@ public class WebSocketIntegrationTests {
@Bean
public RouteLocator wsRouteLocator(RouteLocatorBuilder builder) {
return builder.routes()
.route(r->r.path("/echoForHttp")
.uri("lb://wsservice"))
.route(r -> r.alwaysTrue()
.uri("lb:ws://wsservice"))
.uri("lb:ws://wsservice"))
.build();
}
}

Loading…
Cancel
Save