diff --git a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/WebsocketRoutingFilter.java b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/WebsocketRoutingFilter.java index e84f68257..e0ed875e9 100644 --- a/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/WebsocketRoutingFilter.java +++ b/spring-cloud-gateway-core/src/main/java/org/springframework/cloud/gateway/filter/WebsocketRoutingFilter.java @@ -7,6 +7,8 @@ import java.util.Collections; import java.util.List; import java.util.stream.Collectors; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Mono; import org.springframework.beans.factory.ObjectProvider; @@ -19,6 +21,7 @@ import org.springframework.web.reactive.socket.WebSocketSession; import org.springframework.web.reactive.socket.client.WebSocketClient; import org.springframework.web.reactive.socket.server.WebSocketService; import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.util.UriComponentsBuilder; import static org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter.filterRequest; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR; @@ -30,6 +33,7 @@ import static org.springframework.util.StringUtils.commaDelimitedListToStringArr * @author Spencer Gibb */ public class WebsocketRoutingFilter implements GlobalFilter, Ordered { + private static final Log log = LogFactory.getLog(WebsocketRoutingFilter.class); public static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; private final WebSocketClient webSocketClient; @@ -46,14 +50,17 @@ public class WebsocketRoutingFilter implements GlobalFilter, Ordered { @Override public int getOrder() { - return Ordered.LOWEST_PRECEDENCE; + // Before NettyRoutingFilter since this routes certain http requests + return Ordered.LOWEST_PRECEDENCE - 1; } @Override public Mono filter(ServerWebExchange exchange, GatewayFilterChain chain) { - URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR); + changeSchemeIfIsWebSocketUpgrade(exchange); + URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR); String scheme = requestUrl.getScheme(); + if (isAlreadyRouted(exchange) || (!"ws".equals(scheme) && !"wss".equals(scheme))) { return chain.filter(exchange); } @@ -94,6 +101,24 @@ public class WebsocketRoutingFilter implements GlobalFilter, Ordered { return filters; } + private void changeSchemeIfIsWebSocketUpgrade(ServerWebExchange exchange) { + // Check the Upgrade + URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR); + String scheme = requestUrl.getScheme(); + String upgrade = exchange.getRequest().getHeaders().getUpgrade(); + // change the scheme if the socket client send a "http" or "https" + if ("WebSocket".equalsIgnoreCase(upgrade) && ("http".equals(scheme) || "https".equals(scheme))) { + String wsScheme = convertHttpToWs(scheme); + URI wsRequestUrl = UriComponentsBuilder.fromUri(requestUrl).scheme(wsScheme).build().toUri(); + exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, wsRequestUrl); + log.trace("changeSchemeTo:[" + wsRequestUrl+"]"); + } + } + + private String convertHttpToWs(String scheme) { + return "http".equals(scheme) ? "ws" : "https".equals(scheme) ? "wws" : scheme; + } + private static class ProxyWebSocketHandler implements WebSocketHandler { private final WebSocketClient client; @@ -126,10 +151,10 @@ public class WebsocketRoutingFilter implements GlobalFilter, Ordered { // Use retain() for Reactor Netty Mono proxySessionSend = proxySession .send(session.receive().doOnNext(WebSocketMessage::retain)); - // .log("proxySessionSend", Level.FINE); + // .log("proxySessionSend", Level.FINE); Mono serverSessionSend = session .send(proxySession.receive().doOnNext(WebSocketMessage::retain)); - // .log("sessionSend", Level.FINE); + // .log("sessionSend", Level.FINE); return Mono.when(proxySessionSend, serverSessionSend); } diff --git a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/test/websocket/WebSocketIntegrationTests.java b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/test/websocket/WebSocketIntegrationTests.java index e2416a3a2..953537afb 100644 --- a/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/test/websocket/WebSocketIntegrationTests.java +++ b/spring-cloud-gateway-core/src/test/java/org/springframework/cloud/gateway/test/websocket/WebSocketIntegrationTests.java @@ -141,7 +141,11 @@ public class WebSocketIntegrationTests { protected URI getUrl(String path) throws URISyntaxException { // return new URI("ws://localhost:" + this.serverPort + path); - return new URI("ws://localhost:" + this.gatewayPort + path); + return new URI("ws://localhost:" + this.gatewayPort + path); + } + + protected URI getHttpUrl(String path) throws URISyntaxException { + return new URI("http://localhost:" + this.gatewayPort + path); } @Configuration @@ -170,6 +174,7 @@ public class WebSocketIntegrationTests { public HandlerMapping handlerMapping() { Map map = new HashMap<>(); map.put("/echo", new EchoWebSocketHandler()); + map.put("/echoForHttp", new EchoWebSocketHandler()); map.put("/sub-protocol", new SubProtocolWebSocketHandler()); map.put("/custom-header", new CustomHeaderHandler()); @@ -203,6 +208,31 @@ public class WebSocketIntegrationTests { output.collectList().block(Duration.ofMillis(5000))); } + @Test + public void echoForHttp() throws Exception { + int count = 100; + Flux input = Flux.range(1, count).map(index -> "msg-" + index); + ReplayProcessor output = ReplayProcessor.create(count); + + client.execute(getHttpUrl("/echoForHttp"), + session -> { + logger.debug("Starting to send messages"); + return session + .send(input.doOnNext(s -> logger.debug("outbound " + s)).map(session::textMessage)) + .thenMany(session.receive().take(count).map(WebSocketMessage::getPayloadAsText)) + .subscribeWith(output) + .doOnNext(s -> logger.debug("inbound " + s)) + .then() + .doOnSuccessOrError((aVoid, ex) -> + logger.debug("Done with " + (ex != null ? ex.getMessage() : "success"))); + }) + .block(Duration.ofMillis(5000)); + + assertEquals(input.collectList().block(Duration.ofMillis(5000)), + output.collectList().block(Duration.ofMillis(5000))); + } + + @Test public void subProtocol() throws Exception { String protocol = "echo-v1"; @@ -316,8 +346,10 @@ public class WebSocketIntegrationTests { @Bean public RouteLocator wsRouteLocator(RouteLocatorBuilder builder) { return builder.routes() + .route(r->r.path("/echoForHttp") + .uri("lb://wsservice")) .route(r -> r.alwaysTrue() - .uri("lb:ws://wsservice")) + .uri("lb:ws://wsservice")) .build(); } }