From 52799c0e3d14662f01e1e381ddbbd57981d89d4e Mon Sep 17 00:00:00 2001 From: Juergen Hoeller Date: Fri, 9 Dec 2016 15:21:31 +0100 Subject: [PATCH] Revised Jetty 9.3 vs 9.4 differentiation Issue: SPR-14940 --- .../adapter/AbstractWebSocketSession.java | 7 +- .../adapter/jetty/JettyWebSocketSession.java | 90 ++++----- .../jetty/JettyRequestUpgradeStrategy.java | 175 +++++++++++------- 3 files changed, 155 insertions(+), 117 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSession.java index 469b0150ca..a5e78b6ae0 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/AbstractWebSocketSession.java @@ -42,15 +42,13 @@ public abstract class AbstractWebSocketSession implements NativeWebSocketSess protected static final Log logger = LogFactory.getLog(NativeWebSocketSession.class); + private final Map attributes = new ConcurrentHashMap<>(); private T nativeSession; - private final Map attributes = new ConcurrentHashMap<>(); - /** * Create a new instance and associate the given attributes with it. - * * @param attributes attributes from the HTTP handshake to associate with the WebSocket * session; the provided attributes are copied, the original map is not used. */ @@ -83,7 +81,7 @@ public abstract class AbstractWebSocketSession implements NativeWebSocketSess } public void initializeNativeSession(T session) { - Assert.notNull(session, "session must not be null"); + Assert.notNull(session, "WebSocket session must not be null"); this.nativeSession = session; } @@ -125,6 +123,7 @@ public abstract class AbstractWebSocketSession implements NativeWebSocketSess protected abstract void sendPongMessage(PongMessage message) throws IOException; + @Override public final void close() throws IOException { close(CloseStatus.NORMAL); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java index bc6030f8e5..4062db5af1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketSession.java @@ -45,49 +45,38 @@ import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.AbstractWebSocketSession; /** - * A {@link WebSocketSession} for use with the Jetty 9 WebSocket API. + * A {@link WebSocketSession} for use with the Jetty 9.3/9.4 WebSocket API. * * @author Phillip Webb * @author Rossen Stoyanchev * @author Brian Clozel + * @author Juergen Hoeller * @since 4.0 */ public class JettyWebSocketSession extends AbstractWebSocketSession { // As of Jetty 9.4, UpgradeRequest and UpgradeResponse are interfaces instead of classes - private static final boolean isJetty94; + private static final boolean directInterfaceCalls; private static Method getUpgradeRequest; private static Method getUpgradeResponse; private static Method getRequestURI; private static Method getHeaders; + private static Method getUserPrincipal; private static Method getAcceptedSubProtocol; private static Method getExtensions; - private static Method getUserPrincipal; - - private String id; - - private URI uri; - - private HttpHeaders headers; - - private String acceptedProtocol; - - private List extensions; - - private Principal user; static { - isJetty94 = UpgradeRequest.class.isInterface(); - if (!isJetty94) { + directInterfaceCalls = UpgradeRequest.class.isInterface(); + if (!directInterfaceCalls) { try { getUpgradeRequest = Session.class.getMethod("getUpgradeRequest"); getUpgradeResponse = Session.class.getMethod("getUpgradeResponse"); getRequestURI = UpgradeRequest.class.getMethod("getRequestURI"); getHeaders = UpgradeRequest.class.getMethod("getHeaders"); + getUserPrincipal = UpgradeRequest.class.getMethod("getUserPrincipal"); getAcceptedSubProtocol = UpgradeResponse.class.getMethod("getAcceptedSubProtocol"); getExtensions = UpgradeResponse.class.getMethod("getExtensions"); - getUserPrincipal = UpgradeRequest.class.getMethod("getUserPrincipal"); } catch (NoSuchMethodException ex) { throw new IllegalStateException("Incompatible Jetty API", ex); @@ -95,9 +84,22 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { } } + + private String id; + + private URI uri; + + private HttpHeaders headers; + + private String acceptedProtocol; + + private List extensions; + + private Principal user; + + /** * Create a new {@link JettyWebSocketSession} instance. - * * @param attributes attributes from the HTTP handshake to associate with the WebSocket session */ public JettyWebSocketSession(Map attributes) { @@ -106,11 +108,10 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { /** * Create a new {@link JettyWebSocketSession} instance associated with the given user. - * * @param attributes attributes from the HTTP handshake to associate with the WebSocket * session; the provided attributes are copied, the original map is not used. - * @param user the user associated with the session; if {@code null} we'll fallback on the user - * available via {@link org.eclipse.jetty.websocket.api.Session#getUpgradeRequest()} + * @param user the user associated with the session; if {@code null} we'll fallback on the + * user available via {@link org.eclipse.jetty.websocket.api.Session#getUpgradeRequest()} */ public JettyWebSocketSession(Map attributes, Principal user) { super(attributes); @@ -191,36 +192,32 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { @Override public boolean isOpen() { - return ((getNativeSession() != null) && getNativeSession().isOpen()); + return (getNativeSession() != null && getNativeSession().isOpen()); } + @Override public void initializeNativeSession(Session session) { super.initializeNativeSession(session); - if (isJetty94) { - initializeJetty94Session(session); + if (directInterfaceCalls) { + initializeJettySessionDirectly(session); } else { - initializeJettySession(session); + initializeJettySessionReflectively(session); } } - @SuppressWarnings("unchecked") - private void initializeJettySession(Session session) { - - Object request = ReflectionUtils.invokeMethod(getUpgradeRequest, session); - Object response = ReflectionUtils.invokeMethod(getUpgradeResponse, session); - + private void initializeJettySessionDirectly(Session session) { this.id = ObjectUtils.getIdentityHexString(getNativeSession()); - this.uri = (URI) ReflectionUtils.invokeMethod(getRequestURI, request); + this.uri = session.getUpgradeRequest().getRequestURI(); this.headers = new HttpHeaders(); - this.headers.putAll((Map>) ReflectionUtils.invokeMethod(getHeaders, request)); + this.headers.putAll(session.getUpgradeRequest().getHeaders()); this.headers = HttpHeaders.readOnlyHttpHeaders(headers); - this.acceptedProtocol = (String) ReflectionUtils.invokeMethod(getAcceptedSubProtocol, response); + this.acceptedProtocol = session.getUpgradeResponse().getAcceptedSubProtocol(); - List source = (List) ReflectionUtils.invokeMethod(getExtensions, response); + List source = session.getUpgradeResponse().getExtensions(); if (source != null) { this.extensions = new ArrayList<>(source.size()); for (ExtensionConfig ec : source) { @@ -232,21 +229,25 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { } if (this.user == null) { - this.user = (Principal) ReflectionUtils.invokeMethod(getUserPrincipal, request); + this.user = session.getUpgradeRequest().getUserPrincipal(); } } - private void initializeJetty94Session(Session session) { + @SuppressWarnings("unchecked") + private void initializeJettySessionReflectively(Session session) { + Object request = ReflectionUtils.invokeMethod(getUpgradeRequest, session); + Object response = ReflectionUtils.invokeMethod(getUpgradeResponse, session); + this.id = ObjectUtils.getIdentityHexString(getNativeSession()); - this.uri = session.getUpgradeRequest().getRequestURI(); + this.uri = (URI) ReflectionUtils.invokeMethod(getRequestURI, request); this.headers = new HttpHeaders(); - this.headers.putAll(session.getUpgradeRequest().getHeaders()); + this.headers.putAll((Map>) ReflectionUtils.invokeMethod(getHeaders, request)); this.headers = HttpHeaders.readOnlyHttpHeaders(headers); - this.acceptedProtocol = session.getUpgradeResponse().getAcceptedSubProtocol(); + this.acceptedProtocol = (String) ReflectionUtils.invokeMethod(getAcceptedSubProtocol, response); - List source = session.getUpgradeResponse().getExtensions(); + List source = (List) ReflectionUtils.invokeMethod(getExtensions, response); if (source != null) { this.extensions = new ArrayList<>(source.size()); for (ExtensionConfig ec : source) { @@ -258,10 +259,11 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { } if (this.user == null) { - this.user = session.getUpgradeRequest().getUserPrincipal(); + this.user = (Principal) ReflectionUtils.invokeMethod(getUserPrincipal, request); } } + @Override protected void sendTextMessage(TextMessage message) throws IOException { getRemoteEndpoint().sendString(message.getPayload()); @@ -287,7 +289,7 @@ public class JettyWebSocketSession extends AbstractWebSocketSession { return getNativeSession().getRemote(); } catch (WebSocketException ex) { - throw new IOException("Unable to obtain RemoteEndpoint in session=" + getId(), ex); + throw new IOException("Unable to obtain RemoteEndpoint in session " + getId(), ex); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java index 9bf3260f4e..937c5bdd7b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java @@ -21,7 +21,6 @@ import java.security.Principal; import java.util.ArrayList; import java.util.List; import java.util.Map; - import javax.servlet.ServletContext; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -60,75 +59,59 @@ import org.springframework.web.socket.server.RequestUpgradeStrategy; * @author Phillip Webb * @author Rossen Stoyanchev * @author Brian Clozel + * @author Juergen Hoeller * @since 4.0 */ -public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Lifecycle, ServletContextAware { +public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, ServletContextAware, Lifecycle { private static final ThreadLocal wsContainerHolder = new NamedThreadLocal<>("WebSocket Handler Container"); - // Actually 9.3.15+ - private static boolean isJetty94 = ClassUtils.hasConstructor(WebSocketServerFactory.class, ServletContext.class); + private final WebSocketServerFactoryAdapter factoryAdapter = + (ClassUtils.hasConstructor(WebSocketServerFactory.class, ServletContext.class) ? + new ModernJettyWebSocketServerFactoryAdapter() : new LegacyJettyWebSocketServerFactoryAdapter()); - private WebSocketServerFactoryAdapter factoryAdapter; + private ServletContext servletContext; - private volatile List supportedExtensions; + private volatile boolean running = false; - protected ServletContext servletContext; + private volatile List supportedExtensions; - private volatile boolean running = false; /** * Default constructor that creates {@link WebSocketServerFactory} through * its default constructor thus using a default {@link WebSocketPolicy}. */ public JettyRequestUpgradeStrategy() { - this(WebSocketPolicy.newServerPolicy()); + this.factoryAdapter.setPolicy(WebSocketPolicy.newServerPolicy()); } /** - * A constructor accepting a {@link WebSocketPolicy} - * to be used when creating the {@link WebSocketServerFactory} instance. - * @since 4.3 + * A constructor accepting a {@link WebSocketPolicy} to be used when + * creating the {@link WebSocketServerFactory} instance. + * @param policy the policy to use + * @since 4.3.5 */ - public JettyRequestUpgradeStrategy(WebSocketPolicy webSocketPolicy) { - this.factoryAdapter = isJetty94 ? new Jetty94WebSocketServerFactoryAdapter() - : new JettyWebSocketServerFactoryAdapter(); - this.factoryAdapter.setWebSocketPolicy(webSocketPolicy); + public JettyRequestUpgradeStrategy(WebSocketPolicy policy) { + Assert.notNull(policy, "WebSocketPolicy must not be null"); + this.factoryAdapter.setPolicy(policy); } - @Override - public String[] getSupportedVersions() { - return new String[] {String.valueOf(HandshakeRFC6455.VERSION)}; + /** + * A constructor accepting a {@link WebSocketServerFactory}. + * @param factory the pre-configured factory to use + */ + public JettyRequestUpgradeStrategy(WebSocketServerFactory factory) { + Assert.notNull(factory, "WebSocketServerFactory must not be null"); + this.factoryAdapter.setFactory(factory); } - @Override - public List getSupportedExtensions(ServerHttpRequest request) { - if (this.supportedExtensions == null) { - this.supportedExtensions = getWebSocketExtensions(); - } - return this.supportedExtensions; - } - - private List getWebSocketExtensions() { - List result = new ArrayList<>(); - for (String name : this.factoryAdapter.getFactory().getExtensionFactory().getExtensionNames()) { - result.add(new WebSocketExtension(name)); - } - return result; - } @Override public void setServletContext(ServletContext servletContext) { this.servletContext = servletContext; } - @Override - public boolean isRunning() { - return this.running; - } - - @Override public void start() { if (!isRunning()) { @@ -136,7 +119,7 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life try { this.factoryAdapter.start(); } - catch (Exception ex) { + catch (Throwable ex) { throw new IllegalStateException("Unable to start Jetty WebSocketServerFactory", ex); } } @@ -149,12 +132,39 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life this.running = false; this.factoryAdapter.stop(); } - catch (Exception ex) { + catch (Throwable ex) { throw new IllegalStateException("Unable to stop Jetty WebSocketServerFactory", ex); } } } + @Override + public boolean isRunning() { + return this.running; + } + + + @Override + public String[] getSupportedVersions() { + return new String[] { String.valueOf(HandshakeRFC6455.VERSION) }; + } + + @Override + public List getSupportedExtensions(ServerHttpRequest request) { + if (this.supportedExtensions == null) { + this.supportedExtensions = buildWebSocketExtensions(); + } + return this.supportedExtensions; + } + + private List buildWebSocketExtensions() { + List result = new ArrayList<>(); + for (String name : this.factoryAdapter.getFactory().getExtensionFactory().getExtensionNames()) { + result.add(new WebSocketExtension(name)); + } + return result; + } + @Override public void upgrade(ServerHttpRequest request, ServerHttpResponse response, String selectedProtocol, List selectedExtensions, Principal user, @@ -197,7 +207,9 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life private final List extensionConfigs; - public WebSocketHandlerContainer(JettyWebSocketHandlerAdapter handler, String protocol, List extensions) { + public WebSocketHandlerContainer( + JettyWebSocketHandlerAdapter handler, String protocol, List extensions) { + this.handler = handler; this.selectedProtocol = protocol; if (CollectionUtils.isEmpty(extensions)) { @@ -224,21 +236,29 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life } } + private static abstract class WebSocketServerFactoryAdapter { - protected WebSocketServerFactory factory; + private WebSocketPolicy policy; - protected WebSocketPolicy webSocketPolicy; + private WebSocketServerFactory factory; - public WebSocketServerFactory getFactory() { - return factory; + public void setPolicy(WebSocketPolicy policy) { + this.policy = policy; } - public void setWebSocketPolicy(WebSocketPolicy webSocketPolicy) { - this.webSocketPolicy = webSocketPolicy; + public void setFactory(WebSocketServerFactory factory) { + this.factory = factory; } - protected void configureFactory() { + public WebSocketServerFactory getFactory() { + return this.factory; + } + + public void start() throws Exception { + if (this.factory == null) { + this.factory = createFactory(this.policy); + } this.factory.setCreator(new WebSocketCreator() { @Override public Object createWebSocket(ServletUpgradeRequest request, ServletUpgradeResponse response) { @@ -249,43 +269,60 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Life return container.getHandler(); } }); + startFactory(this.factory); } - abstract void start() throws Exception; + public void stop() throws Exception { + if (this.factory != null) { + stopFactory(this.factory); + } + } + + protected abstract WebSocketServerFactory createFactory(WebSocketPolicy policy) throws Exception; - abstract void stop() throws Exception; + protected abstract void startFactory(WebSocketServerFactory factory) throws Exception; + + protected abstract void stopFactory(WebSocketServerFactory factory) throws Exception; } - private class JettyWebSocketServerFactoryAdapter extends WebSocketServerFactoryAdapter { + + // Jetty 9.3.15+ + private class ModernJettyWebSocketServerFactoryAdapter extends WebSocketServerFactoryAdapter { @Override - void start() throws Exception { - this.factory = WebSocketServerFactory.class.getConstructor(WebSocketPolicy.class) - .newInstance(this.webSocketPolicy); - configureFactory(); - WebSocketServerFactory.class.getMethod("init", ServletContext.class) - .invoke(this.factory, servletContext); + protected WebSocketServerFactory createFactory(WebSocketPolicy policy) throws Exception { + servletContext.setAttribute(DecoratedObjectFactory.ATTR, new DecoratedObjectFactory()); + return new WebSocketServerFactory(servletContext, policy); } @Override - void stop() throws Exception { - WebSocketServerFactory.class.getMethod("cleanup").invoke(this.factory); + protected void startFactory(WebSocketServerFactory factory) throws Exception { + factory.start(); + } + + @Override + protected void stopFactory(WebSocketServerFactory factory) throws Exception { + factory.stop(); } } - private class Jetty94WebSocketServerFactoryAdapter extends WebSocketServerFactoryAdapter { + + // Jetty <9.3.15 + private class LegacyJettyWebSocketServerFactoryAdapter extends WebSocketServerFactoryAdapter { @Override - void start() throws Exception { - servletContext.setAttribute(DecoratedObjectFactory.ATTR, new DecoratedObjectFactory()); - this.factory = new WebSocketServerFactory(servletContext, this.webSocketPolicy); - configureFactory(); - this.factory.start(); + protected WebSocketServerFactory createFactory(WebSocketPolicy policy) throws Exception { + return WebSocketServerFactory.class.getConstructor(WebSocketPolicy.class).newInstance(policy); + } + + @Override + protected void startFactory(WebSocketServerFactory factory) throws Exception { + WebSocketServerFactory.class.getMethod("init", ServletContext.class).invoke(factory, servletContext); } @Override - void stop() throws Exception { - this.factory.stop(); + protected void stopFactory(WebSocketServerFactory factory) throws Exception { + WebSocketServerFactory.class.getMethod("cleanup").invoke(factory); } }