diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java index 946f061af3..34e6b89f01 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java @@ -16,7 +16,6 @@ package org.springframework.web.socket.server.jetty; -import java.lang.reflect.Method; import java.lang.reflect.UndeclaredThrowableException; import java.security.Principal; import java.util.Collections; @@ -26,19 +25,15 @@ import java.util.Map; import jakarta.servlet.ServletContext; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; -import org.aopalliance.intercept.MethodInterceptor; -import org.aopalliance.intercept.MethodInvocation; +import org.eclipse.jetty.websocket.server.JettyWebSocketCreator; +import org.eclipse.jetty.websocket.server.JettyWebSocketServerContainer; -import org.springframework.aop.framework.ProxyFactory; -import org.springframework.aop.target.EmptyTargetSource; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; -import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.Assert; -import org.springframework.util.ReflectionUtils; import org.springframework.web.socket.WebSocketExtension; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.adapter.jetty.JettyWebSocketHandlerAdapter; @@ -56,35 +51,6 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { private static final String[] SUPPORTED_VERSIONS = new String[] { String.valueOf(13) }; - private static final Class webSocketCreatorClass; - - private static final Method getContainerMethod; - - private static final Method upgradeMethod; - - private static final Method setAcceptedSubProtocol; - - static { - // TODO: can switch to non-reflective implementation now - - ClassLoader loader = JettyRequestUpgradeStrategy.class.getClassLoader(); - try { - webSocketCreatorClass = loader.loadClass("org.eclipse.jetty.websocket.server.JettyWebSocketCreator"); - - Class type = loader.loadClass("org.eclipse.jetty.websocket.server.JettyWebSocketServerContainer"); - getContainerMethod = type.getMethod("getContainer", ServletContext.class); - Method upgrade = ReflectionUtils.findMethod(type, "upgrade", (Class[]) null); - Assert.state(upgrade != null, "Upgrade method not found"); - upgradeMethod = upgrade; - - type = loader.loadClass("org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse"); - setAcceptedSubProtocol = type.getMethod("setAcceptedSubProtocol", String.class); - } - catch (Exception ex) { - throw new IllegalStateException("No compatible Jetty version found", ex); - } - } - @Override public String[] getSupportedVersions() { @@ -113,10 +79,17 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { JettyWebSocketSession session = new JettyWebSocketSession(attributes, user); JettyWebSocketHandlerAdapter handlerAdapter = new JettyWebSocketHandlerAdapter(handler, session); + JettyWebSocketCreator webSocketCreator = (upgradeRequest, upgradeResponse) -> { + if (selectedProtocol != null) { + upgradeResponse.setAcceptedSubProtocol(selectedProtocol); + } + return handlerAdapter; + }; + + JettyWebSocketServerContainer container = JettyWebSocketServerContainer.getContainer(servletContext); + try { - Object creator = createJettyWebSocketCreator(handlerAdapter, selectedProtocol); - Object container = ReflectionUtils.invokeMethod(getContainerMethod, null, servletContext); - ReflectionUtils.invokeMethod(upgradeMethod, container, creator, servletRequest, servletResponse); + container.upgrade(webSocketCreator, servletRequest, servletResponse); } catch (UndeclaredThrowableException ex) { throw new HandshakeFailureException("Failed to upgrade", ex.getUndeclaredThrowable()); @@ -126,40 +99,5 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy { } } - private static Object createJettyWebSocketCreator( - JettyWebSocketHandlerAdapter adapter, @Nullable String protocol) { - - ProxyFactory factory = new ProxyFactory(EmptyTargetSource.INSTANCE); - factory.addInterface(webSocketCreatorClass); - factory.addAdvice(new WebSocketCreatorInterceptor(adapter, protocol)); - return factory.getProxy(); - } - - - /** - * Proxy for a JettyWebSocketCreator to supply the WebSocket handler and set the sub-protocol. - */ - private static class WebSocketCreatorInterceptor implements MethodInterceptor { - - private final JettyWebSocketHandlerAdapter adapter; - - @Nullable - private final String protocol; - - public WebSocketCreatorInterceptor(JettyWebSocketHandlerAdapter adapter, @Nullable String protocol) { - this.adapter = adapter; - this.protocol = protocol; - } - - @Nullable - @Override - public Object invoke(@NonNull MethodInvocation invocation) { - if (this.protocol != null) { - ReflectionUtils.invokeMethod( - setAcceptedSubProtocol, invocation.getArguments()[2], this.protocol); - } - return this.adapter; - } - } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java b/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java index abe8c4d095..67a6771779 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/JettyWebSocketTestServer.java @@ -27,6 +27,7 @@ import org.eclipse.jetty.server.Server; import org.eclipse.jetty.servlet.FilterHolder; import org.eclipse.jetty.servlet.ServletContextHandler; import org.eclipse.jetty.servlet.ServletHolder; +import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.DispatcherServlet; @@ -57,6 +58,7 @@ public class JettyWebSocketTestServer implements WebSocketTestServer { ServletHolder servletHolder = new ServletHolder(new DispatcherServlet(wac)); this.contextHandler = new ServletContextHandler(); this.contextHandler.addServlet(servletHolder, "/"); + this.contextHandler.addServletContainerInitializer(new JettyWebSocketServletContainerInitializer()); for (Filter filter : filters) { this.contextHandler.addFilter(new FilterHolder(filter), "/*", getDispatcherTypes()); }