From c67b69433907bfc7108d2a85e849b0cdd12468d0 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 28 May 2013 10:06:53 -0400 Subject: [PATCH] Add STOMP service that relays messages to STOMP broker --- .../web/stomp/StompMessage.java | 2 +- .../ReactorServerStompMessageProcessor.java | 57 +++-- .../server/RelayStompReactorService.java | 233 ++++++++++++++++++ spring-websocket/src/test/resources/log4j.xml | 4 + 4 files changed, 277 insertions(+), 19 deletions(-) create mode 100644 spring-websocket/src/main/java/org/springframework/web/stomp/server/RelayStompReactorService.java diff --git a/spring-websocket/src/main/java/org/springframework/web/stomp/StompMessage.java b/spring-websocket/src/main/java/org/springframework/web/stomp/StompMessage.java index 421dde2a17..96a85e4ccc 100644 --- a/spring-websocket/src/main/java/org/springframework/web/stomp/StompMessage.java +++ b/spring-websocket/src/main/java/org/springframework/web/stomp/StompMessage.java @@ -72,7 +72,7 @@ public class StompMessage { @Override public String toString() { - return "StompMessage [headers=" + this.headers + ", payload=" + new String(this.payload) + "]"; + return "StompMessage [" + command + ", headers=" + this.headers + ", payload=" + new String(this.payload) + "]"; } } diff --git a/spring-websocket/src/main/java/org/springframework/web/stomp/server/ReactorServerStompMessageProcessor.java b/spring-websocket/src/main/java/org/springframework/web/stomp/server/ReactorServerStompMessageProcessor.java index 3e9dd5c8a6..cabc8ec7f5 100644 --- a/spring-websocket/src/main/java/org/springframework/web/stomp/server/ReactorServerStompMessageProcessor.java +++ b/spring-websocket/src/main/java/org/springframework/web/stomp/server/ReactorServerStompMessageProcessor.java @@ -42,7 +42,6 @@ import reactor.fn.Registration; * @author Gary Russell * @author Rossen Stoyanchev * @since 4.0 - * */ public class ReactorServerStompMessageProcessor implements StompMessageProcessor { @@ -51,7 +50,7 @@ public class ReactorServerStompMessageProcessor implements StompMessageProcessor private final Reactor reactor; - private Map>> subscriptionsBySession = new ConcurrentHashMap>>(); + private Map>> registrationsBySession = new ConcurrentHashMap>>(); public ReactorServerStompMessageProcessor(Reactor reactor) { @@ -59,7 +58,6 @@ public class ReactorServerStompMessageProcessor implements StompMessageProcessor } public void processMessage(StompSession session, StompMessage message) { - try { StompCommand command = message.getCommand(); if (StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command)) { @@ -97,8 +95,8 @@ public class ReactorServerStompMessageProcessor implements StompMessageProcessor private void handleError(final StompSession session, Throwable t) { logger.error("Terminating STOMP session due to failure to send message: ", t); sendErrorMessage(session, t.getMessage()); - if (removeSubscriptions(session.getId())) { - // TODO: send error event and including exception info + if (removeSubscriptions(session)) { + // TODO: send error event including exception info } } @@ -114,7 +112,7 @@ public class ReactorServerStompMessageProcessor implements StompMessageProcessor } } - protected void connect(StompSession session, StompMessage stompMessage) throws IOException { + protected void connect(final StompSession session, StompMessage stompMessage) throws IOException { StompHeaders headers = new StompHeaders(); Set acceptVersions = stompMessage.getHeaders().getAcceptVersion(); @@ -137,7 +135,31 @@ public class ReactorServerStompMessageProcessor implements StompMessageProcessor session.sendMessage(new StompMessage(StompCommand.CONNECTED, headers)); - this.reactor.notify(StompCommand.CONNECT, Fn.event(stompMessage)); + String replyToKey = "relay-message" + session.getId(); + + Registration registration = this.reactor.on(Fn.$(replyToKey), new Consumer>() { + @Override + public void accept(Event event) { + try { + StompMessage message = event.getData(); + if (StompCommand.CONNECTED.equals(message.getCommand())) { + // TODO: skip for now (we already sent CONNECTED) + return; + } + if (logger.isTraceEnabled()) { + logger.trace("Relaying back to client: " + message); + } + session.sendMessage(message); + } + catch (Throwable t) { + handleError(session, t); + } + } + }); + + addRegistration(session.getId(), registration); + + this.reactor.notify(StompCommand.CONNECT, Fn.event(stompMessage, replyToKey)); } protected void subscribe(final StompSession session, StompMessage stompMessage) { @@ -165,7 +187,7 @@ public class ReactorServerStompMessageProcessor implements StompMessageProcessor } }); - addSubscription(session.getId(), registration); + addRegistration(session.getId(), registration); this.reactor.notify(StompCommand.SUBSCRIBE, Fn.event(stompMessage, replyToKey)); @@ -174,11 +196,11 @@ public class ReactorServerStompMessageProcessor implements StompMessageProcessor // http://stomp.github.io/stomp-specification-1.2.html#SUBSCRIBE } - private void addSubscription(String sessionId, Registration registration) { - List> list = this.subscriptionsBySession.get(sessionId); + private void addRegistration(String sessionId, Registration registration) { + List> list = this.registrationsBySession.get(sessionId); if (list == null) { list = new ArrayList>(); - this.subscriptionsBySession.put(sessionId, list); + this.registrationsBySession.put(sessionId, list); } list.add(registration); } @@ -192,13 +214,13 @@ public class ReactorServerStompMessageProcessor implements StompMessageProcessor } protected void disconnect(StompSession session, StompMessage stompMessage) { - String sessionId = session.getId(); - removeSubscriptions(sessionId); + removeSubscriptions(session); this.reactor.notify(StompCommand.DISCONNECT, Fn.event(stompMessage)); } - private boolean removeSubscriptions(String sessionId) { - List> registrations = this.subscriptionsBySession.remove(sessionId); + private boolean removeSubscriptions(StompSession session) { + String sessionId = session.getId(); + List> registrations = this.registrationsBySession.remove(sessionId); if (CollectionUtils.isEmpty(registrations)) { return false; } @@ -213,9 +235,8 @@ public class ReactorServerStompMessageProcessor implements StompMessageProcessor @Override public void processConnectionClosed(StompSession session) { - if (removeSubscriptions(session.getId())) { - // TODO: this implies abnormal closure from the underlying transport (no DISCONNECT) .. send an error event - } + removeSubscriptions(session); + this.reactor.notify("CONNECTION_CLOSED", Fn.event(session.getId())); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/stomp/server/RelayStompReactorService.java b/spring-websocket/src/main/java/org/springframework/web/stomp/server/RelayStompReactorService.java new file mode 100644 index 0000000000..97274b59cf --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/stomp/server/RelayStompReactorService.java @@ -0,0 +1,233 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.stomp.server; + +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.Socket; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import javax.net.SocketFactory; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.core.task.TaskExecutor; +import org.springframework.web.stomp.StompCommand; +import org.springframework.web.stomp.StompHeaders; +import org.springframework.web.stomp.StompMessage; +import org.springframework.web.stomp.support.StompMessageConverter; + +import reactor.Fn; +import reactor.core.Reactor; +import reactor.fn.Consumer; +import reactor.fn.Event; +import reactor.util.Assert; + + +/** + * @author Rossen Stoyanchev + * @since 4.0 + */ +public class RelayStompReactorService { + + private static final Log logger = LogFactory.getLog(RelayStompReactorService.class); + + + private final Reactor reactor; + + private Map relaySessions = new ConcurrentHashMap(); + + private final StompMessageConverter converter = new StompMessageConverter(); + + private final TaskExecutor taskExecutor; + + + public RelayStompReactorService(Reactor reactor, TaskExecutor executor) { + this.reactor = reactor; + this.taskExecutor = executor; // For now, a naively way to manage socket reading + + this.reactor.on(Fn.$(StompCommand.CONNECT), new ConnectConsumer()); + this.reactor.on(Fn.$(StompCommand.SUBSCRIBE), new RelayConsumer()); + this.reactor.on(Fn.$(StompCommand.SEND), new RelayConsumer()); + this.reactor.on(Fn.$(StompCommand.DISCONNECT), new RelayConsumer()); + + this.reactor.on(Fn.$("CONNECTION_CLOSED"), new Consumer>() { + @Override + public void accept(Event event) { + if (logger.isDebugEnabled()) { + logger.debug("CONNECTION_CLOSED, STOMP session=" + event.getData() + ". Clearing relay session"); + } + clearRelaySession(event.getData()); + } + }); + } + + private void relayStompMessage(RelaySession session, StompMessage stompMessage) throws Exception { + if (logger.isTraceEnabled()) { + logger.trace("Forwarding: " + stompMessage); + } + byte[] bytes = converter.fromStompMessage(stompMessage); + session.getOutputStream().write(bytes); + session.getOutputStream().flush(); + } + + private RelaySession getRelaySession(String stompSessionId) { + RelaySession session = RelayStompReactorService.this.relaySessions.get(stompSessionId); + Assert.notNull(session, "RelaySession not found"); + return session; + } + + private void clearRelaySession(String stompSessionId) { + RelaySession relaySession = this.relaySessions.remove(stompSessionId); + if (relaySession != null) { + // TODO: raise failure event so client session can be closed + try { + relaySession.getSocket().close(); + } + catch (IOException e) { + // ignore + } + } + } + + + private final class ConnectConsumer implements Consumer> { + + @Override + public void accept(Event event) { + + StompMessage stompMessage = event.getData(); + final Object replyTo = event.getReplyTo(); + final String stompSessionId = stompMessage.getStompSessionId(); + + final RelaySession session = new RelaySession(); + relaySessions.put(stompSessionId, session); + + try { + Socket socket = SocketFactory.getDefault().createSocket("127.0.0.1", 61613); + session.setSocket(socket); + + relayStompMessage(session, stompMessage); + + taskExecutor.execute(new RelayReadTask(stompSessionId, replyTo, session)); + } + catch (Throwable t) { + t.printStackTrace(); + clearRelaySession(stompSessionId); + } + } + } + + private final static class RelaySession { + + private Socket socket; + + private InputStream inputStream; + + private OutputStream outputStream; + + + public void setSocket(Socket socket) throws IOException { + this.socket = socket; + this.inputStream = new BufferedInputStream(socket.getInputStream()); + this.outputStream = new BufferedOutputStream(socket.getOutputStream()); + } + + public Socket getSocket() { + return this.socket; + } + + public InputStream getInputStream() { + return this.inputStream; + } + + public OutputStream getOutputStream() { + return this.outputStream; + } + } + + private final class RelayReadTask implements Runnable { + + private final String stompSessionId; + private final Object replyTo; + private final RelaySession session; + + private RelayReadTask(String stompSessionId, Object replyTo, RelaySession session) { + this.stompSessionId = stompSessionId; + this.replyTo = replyTo; + this.session = session; + } + + @Override + public void run() { + try { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + while (!session.getSocket().isClosed()) { + int b = session.getInputStream().read(); + if (b == -1) { + break; + } + else if (b == 0x00) { + byte[] bytes = out.toByteArray(); + StompMessage message = RelayStompReactorService.this.converter.toStompMessage(bytes); + RelayStompReactorService.this.reactor.notify(replyTo, Fn.event(message)); + out.reset(); + } + else { + out.write(b); + } + } + logger.debug("Socket closed, STOMP session=" + stompSessionId); + sendLostConnectionErrorMessage(); + } + catch (IOException e) { + e.printStackTrace(); + clearRelaySession(stompSessionId); + } + } + + private void sendLostConnectionErrorMessage() { + StompHeaders headers = new StompHeaders(); + headers.setMessage("Lost connection"); + StompMessage errorMessage = new StompMessage(StompCommand.ERROR, headers); + RelayStompReactorService.this.reactor.notify(replyTo, Fn.event(errorMessage)); + } + } + + private class RelayConsumer implements Consumer> { + + @Override + public void accept(Event event) { + StompMessage stompMessage = event.getData(); + RelaySession session = getRelaySession(stompMessage.getStompSessionId()); + try { + relayStompMessage(session, stompMessage); + } + catch (Exception e) { + // TODO Auto-generated catch block + e.printStackTrace(); + clearRelaySession(stompMessage.getStompSessionId()); + } + } + } + +} diff --git a/spring-websocket/src/test/resources/log4j.xml b/spring-websocket/src/test/resources/log4j.xml index 8fa59bf2f3..8b7bac45ad 100644 --- a/spring-websocket/src/test/resources/log4j.xml +++ b/spring-websocket/src/test/resources/log4j.xml @@ -22,6 +22,10 @@ + + + +