From 1dd7d53de0b32813564d7b7c0627ade4d7f47b2b Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Fri, 19 Feb 2021 11:49:44 +0000 Subject: [PATCH] More precise mapping for WebSocket handshake requests Closes gh-26565 --- .../handler/AbstractUrlHandlerMapping.java | 35 +++++-- .../WebSocketUpgradeHandlerPredicate.java | 45 +++++++++ ...WebSocketUpgradeHandlerPredicateTests.java | 97 +++++++++++++++++++ .../support/WebSocketHandlerMapping.java | 49 +++++++++- .../support/WebSocketHandlerMappingTests.java | 70 +++++++++++++ 5 files changed, 284 insertions(+), 12 deletions(-) create mode 100644 spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/WebSocketUpgradeHandlerPredicate.java create mode 100644 spring-webflux/src/test/java/org/springframework/web/reactive/socket/server/support/WebSocketUpgradeHandlerPredicateTests.java create mode 100644 spring-websocket/src/test/java/org/springframework/web/socket/server/support/WebSocketHandlerMappingTests.java diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java b/spring-webflux/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java index 9a93a65a4d..abaf4cf86f 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.function.BiPredicate; import reactor.core.publisher.Mono; @@ -57,6 +58,9 @@ public abstract class AbstractUrlHandlerMapping extends AbstractHandlerMapping { private final Map handlerMap = new LinkedHashMap<>(); + @Nullable + private BiPredicate handlerPredicate; + /** * Set whether to lazily initialize handlers. Only applicable to @@ -81,6 +85,23 @@ public abstract class AbstractUrlHandlerMapping extends AbstractHandlerMapping { return Collections.unmodifiableMap(this.handlerMap); } + /** + * Configure a predicate for extended matching of the handler that was + * matched by URL path. This allows for further narrowing of the mapping by + * checking additional properties of the request. If the predicate returns + * "false", it result in a no-match, which allows another + * {@link org.springframework.web.reactive.HandlerMapping} to match or + * result in a 404 (NOT_FOUND) response. + * @param handlerPredicate a bi-predicate to match the candidate handler + * against the current exchange. + * @since 5.3.5 + * @see org.springframework.web.reactive.socket.server.support.WebSocketUpgradeHandlerPredicate + */ + public void setHandlerPredicate(BiPredicate handlerPredicate) { + this.handlerPredicate = (this.handlerPredicate != null ? + this.handlerPredicate.and(handlerPredicate) : handlerPredicate); + } + @Override public Mono getHandlerInternal(ServerWebExchange exchange) { @@ -129,11 +150,7 @@ public abstract class AbstractUrlHandlerMapping extends AbstractHandlerMapping { PathPattern.PathMatchInfo matchInfo = pattern.matchAndExtract(lookupPath); Assert.notNull(matchInfo, "Expected a match"); - return handleMatch(this.handlerMap.get(pattern), pattern, pathWithinMapping, matchInfo, exchange); - } - - private Object handleMatch(Object handler, PathPattern bestMatch, PathContainer pathWithinMapping, - PathPattern.PathMatchInfo matchInfo, ServerWebExchange exchange) { + Object handler = this.handlerMap.get(pattern); // Bean name or resolved handler? if (handler instanceof String) { @@ -141,10 +158,14 @@ public abstract class AbstractUrlHandlerMapping extends AbstractHandlerMapping { handler = obtainApplicationContext().getBean(handlerName); } + if (this.handlerPredicate != null && !this.handlerPredicate.test(handler, exchange)) { + return null; + } + validateHandler(handler, exchange); exchange.getAttributes().put(BEST_MATCHING_HANDLER_ATTRIBUTE, handler); - exchange.getAttributes().put(BEST_MATCHING_PATTERN_ATTRIBUTE, bestMatch); + exchange.getAttributes().put(BEST_MATCHING_PATTERN_ATTRIBUTE, pattern); exchange.getAttributes().put(PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, pathWithinMapping); exchange.getAttributes().put(URI_TEMPLATE_VARIABLES_ATTRIBUTE, matchInfo.getUriVariables()); diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/WebSocketUpgradeHandlerPredicate.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/WebSocketUpgradeHandlerPredicate.java new file mode 100644 index 0000000000..c89312246d --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/WebSocketUpgradeHandlerPredicate.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.reactive.socket.server.support; + +import java.util.function.BiPredicate; + +import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.server.ServerWebExchange; + +/** + * A predicate for use with + * {@link org.springframework.web.reactive.handler.AbstractUrlHandlerMapping#setHandlerPredicate(BiPredicate)} + * to ensure only WebSocket handshake requests are matched to handlers of + * type {@link WebSocketHandler}. + * + * @author Rossen Stoyanchev + * @since 5.3.5 + */ +public class WebSocketUpgradeHandlerPredicate implements BiPredicate { + + + @Override + public boolean test(Object handler, ServerWebExchange exchange) { + if (handler instanceof WebSocketHandler) { + String method = exchange.getRequest().getMethodValue(); + String header = exchange.getRequest().getHeaders().getUpgrade(); + return (method.equals("GET") && header != null && header.equalsIgnoreCase("websocket")); + } + return true; + } + +} diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/socket/server/support/WebSocketUpgradeHandlerPredicateTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/socket/server/support/WebSocketUpgradeHandlerPredicateTests.java new file mode 100644 index 0000000000..2e06acd072 --- /dev/null +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/socket/server/support/WebSocketUpgradeHandlerPredicateTests.java @@ -0,0 +1,97 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.reactive.socket.server.support; + +import java.util.Collections; + +import org.junit.jupiter.api.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.web.context.support.StaticWebApplicationContext; +import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping; +import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest; +import org.springframework.web.testfixture.server.MockServerWebExchange; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for and related to the use of {@link WebSocketUpgradeHandlerPredicate}. + * + * @author Rossen Stoyanchev + */ +public class WebSocketUpgradeHandlerPredicateTests { + + private final WebSocketUpgradeHandlerPredicate predicate = new WebSocketUpgradeHandlerPredicate(); + + private final WebSocketHandler webSocketHandler = mock(WebSocketHandler.class); + + ServerWebExchange httpGetExchange = + MockServerWebExchange.from(MockServerHttpRequest.get("/path")); + + ServerWebExchange httpPostExchange = + MockServerWebExchange.from(MockServerHttpRequest.post("/path")); + + ServerWebExchange webSocketExchange = + MockServerWebExchange.from(MockServerHttpRequest.get("/path").header(HttpHeaders.UPGRADE, "websocket")); + + + @Test + void match() { + assertThat(this.predicate.test(this.webSocketHandler, this.webSocketExchange)) + .as("Should match WebSocketHandler to WebSocket upgrade") + .isTrue(); + + assertThat(this.predicate.test(new Object(), this.httpGetExchange)) + .as("Should match non-WebSocketHandler to any request") + .isTrue(); + } + + @Test + void noMatch() { + assertThat(this.predicate.test(this.webSocketHandler, this.httpGetExchange)) + .as("Should not match WebSocket handler to HTTP GET") + .isFalse(); + + assertThat(this.predicate.test(this.webSocketHandler, this.httpPostExchange)) + .as("Should not match WebSocket handler to HTTP POST") + .isFalse(); + } + + @Test + void simpleUrlHandlerMapping() { + SimpleUrlHandlerMapping mapping = new SimpleUrlHandlerMapping(); + mapping.setUrlMap(Collections.singletonMap("/path", this.webSocketHandler)); + mapping.setApplicationContext(new StaticWebApplicationContext()); + + Object actual = mapping.getHandler(httpGetExchange).block(); + assertThat(actual).as("Should match HTTP GET by URL path").isSameAs(this.webSocketHandler); + + mapping.setHandlerPredicate(new WebSocketUpgradeHandlerPredicate()); + + actual = mapping.getHandler(this.httpGetExchange).block(); + assertThat(actual).as("Should not match if not a WebSocket upgrade").isNull(); + + actual = mapping.getHandler(this.httpPostExchange).block(); + assertThat(actual).as("Should not match if not a WebSocket upgrade").isNull(); + + actual = mapping.getHandler(this.webSocketExchange).block(); + assertThat(actual).as("Should match WebSocket upgrade").isSameAs(this.webSocketHandler); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHandlerMapping.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHandlerMapping.java index 6d46b2d936..1ce49177a5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHandlerMapping.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHandlerMapping.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,26 +17,47 @@ package org.springframework.web.socket.server.support; import javax.servlet.ServletContext; +import javax.servlet.http.HttpServletRequest; import org.springframework.context.Lifecycle; import org.springframework.context.SmartLifecycle; +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; import org.springframework.web.context.ServletContextAware; +import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; /** - * An extension of {@link SimpleUrlHandlerMapping} that is also a - * {@link SmartLifecycle} container and propagates start and stop calls to any - * handlers that implement {@link Lifecycle}. The handlers are typically expected - * to be {@code WebSocketHttpRequestHandler} or {@code SockJsHttpRequestHandler}. + * Extension of {@link SimpleUrlHandlerMapping} with support for more + * precise mapping of WebSocket handshake requests to handlers of type + * {@link WebSocketHttpRequestHandler}. Also delegates {@link Lifecycle} + * methods to handlers in the {@link #getUrlMap()} that implement it. * * @author Rossen Stoyanchev * @since 4.2 */ public class WebSocketHandlerMapping extends SimpleUrlHandlerMapping implements SmartLifecycle { + private boolean webSocketUpgradeMatch; + private volatile boolean running; + /** + * When this is set, if the matched handler is + * {@link WebSocketHttpRequestHandler}, ensure the request is a WebSocket + * handshake, i.e. HTTP GET with the header {@code "Upgrade:websocket"}, + * or otherwise suppress the match and return {@code null} allowing another + * {@link org.springframework.web.servlet.HandlerMapping} to match for the + * same URL path. + * @param match whether to enable matching on {@code "Upgrade: websocket"} + * @since 5.3.5 + */ + public void setWebSocketUpgradeMatch(boolean match) { + this.webSocketUpgradeMatch = match; + } + + @Override protected void initServletContext(ServletContext servletContext) { for (Object handler : getUrlMap().values()) { @@ -76,4 +97,22 @@ public class WebSocketHandlerMapping extends SimpleUrlHandlerMapping implements return this.running; } + @Nullable + @Override + protected Object getHandlerInternal(HttpServletRequest request) throws Exception { + Object handler = super.getHandlerInternal(request); + return matchWebSocketUpgrade(handler, request) ? handler : null; + } + + private boolean matchWebSocketUpgrade(@Nullable Object handler, HttpServletRequest request) { + handler = (handler instanceof HandlerExecutionChain ? + ((HandlerExecutionChain) handler).getHandler() : handler); + if (this.webSocketUpgradeMatch && handler instanceof WebSocketHttpRequestHandler) { + String header = request.getHeader(HttpHeaders.UPGRADE); + return (request.getMethod().equals("GET") && + header != null && header.equalsIgnoreCase("websocket")); + } + return true; + } + } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/WebSocketHandlerMappingTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/WebSocketHandlerMappingTests.java new file mode 100644 index 0000000000..b5ec3cdfdb --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/WebSocketHandlerMappingTests.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.socket.server.support; + +import java.util.Collections; + +import org.junit.jupiter.api.Test; + +import org.springframework.web.HttpRequestHandler; +import org.springframework.web.context.support.StaticWebApplicationContext; +import org.springframework.web.servlet.HandlerExecutionChain; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.testfixture.servlet.MockHttpServletRequest; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link WebSocketHandlerMapping}. + * + * @author Rossen Stoyanchev + */ +public class WebSocketHandlerMappingTests { + + + @Test + void webSocketHandshakeMatch() throws Exception { + HttpRequestHandler handler = new WebSocketHttpRequestHandler(mock(WebSocketHandler.class)); + + WebSocketHandlerMapping mapping = new WebSocketHandlerMapping(); + mapping.setUrlMap(Collections.singletonMap("/path", handler)); + mapping.setApplicationContext(new StaticWebApplicationContext()); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/path"); + + HandlerExecutionChain chain = mapping.getHandler(request); + assertThat(chain).isNotNull(); + assertThat(chain.getHandler()).isSameAs(handler); + + mapping.setWebSocketUpgradeMatch(true); + + chain = mapping.getHandler(request); + assertThat(chain).isNull(); + + request.addHeader("Upgrade", "websocket"); + + chain = mapping.getHandler(request); + assertThat(chain).isNotNull(); + assertThat(chain.getHandler()).isSameAs(handler); + + request.setMethod("POST"); + + chain = mapping.getHandler(request); + assertThat(chain).isNull(); + } + +}