ws.rs 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. use crate::{connect::WsConnection, errors::WsError, WsMessage};
  2. use flowy_net::errors::ServerError;
  3. use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
  4. use futures_core::{ready, Stream};
  5. use crate::connect::Retry;
  6. use futures_core::future::BoxFuture;
  7. use pin_project::pin_project;
  8. use std::{
  9. collections::HashMap,
  10. future::Future,
  11. marker::PhantomData,
  12. pin::Pin,
  13. sync::Arc,
  14. task::{Context, Poll},
  15. };
  16. use tokio::task::JoinHandle;
  17. use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
  18. pub type MsgReceiver = UnboundedReceiver<Message>;
  19. pub type MsgSender = UnboundedSender<Message>;
  20. pub trait WsMessageHandler: Sync + Send + 'static {
  21. fn source(&self) -> String;
  22. fn receive_message(&self, msg: WsMessage);
  23. }
  24. pub struct WsController {
  25. sender: Option<Arc<WsSender>>,
  26. handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
  27. addr: Option<String>,
  28. }
  29. impl WsController {
  30. pub fn new() -> Self {
  31. let controller = Self {
  32. sender: None,
  33. handlers: HashMap::new(),
  34. addr: None,
  35. };
  36. controller
  37. }
  38. pub fn add_handler(&mut self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
  39. let source = handler.source();
  40. if self.handlers.contains_key(&source) {
  41. return Err(WsError::duplicate_source());
  42. }
  43. self.handlers.insert(source, handler);
  44. Ok(())
  45. }
  46. pub fn connect(&mut self, addr: String) -> Result<JoinHandle<()>, ServerError> { self._connect(addr.clone(), None) }
  47. pub fn connect_with_retry<F>(&mut self, addr: String, retry: Retry<F>) -> Result<JoinHandle<()>, ServerError>
  48. where
  49. F: Fn(&str) + Send + Sync + 'static,
  50. {
  51. self._connect(addr, Some(Box::pin(async { retry.await })))
  52. }
  53. fn _connect(&mut self, addr: String, retry: Option<BoxFuture<'static, ()>>) -> Result<JoinHandle<()>, ServerError> {
  54. log::debug!("๐Ÿด ws connect: {}", &addr);
  55. let (connection, handlers) = self.make_connect(addr.clone());
  56. Ok(tokio::spawn(async move {
  57. match connection.await {
  58. Ok(stream) => {
  59. tokio::select! {
  60. result = stream => {
  61. match result {
  62. Ok(_) => {},
  63. Err(e) => {
  64. // TODO: retry?
  65. log::error!("ws stream error {:?}", e);
  66. }
  67. }
  68. },
  69. result = handlers => log::debug!("handlers completed {:?}", result),
  70. };
  71. },
  72. Err(e) => match retry {
  73. None => log::error!("ws connect {} failed {:?}", addr, e),
  74. Some(retry) => {
  75. tokio::spawn(retry);
  76. },
  77. },
  78. }
  79. }))
  80. }
  81. fn make_connect(&mut self, addr: String) -> (WsConnection, WsHandlers) {
  82. // Stream User
  83. // โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
  84. // โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚
  85. // โ”‚Serverโ”‚โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ–ถโ”‚ ws_read โ”‚โ”€โ”€โ”ผโ”€โ”€โ”€โ–ถโ”‚ msg_tx โ”‚โ”€โ”€โ”€โ”ผโ”€โ–ถโ”‚ msg_rx โ”‚ โ”‚
  86. // โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚
  87. // โ–ฒ โ”‚ โ”‚ โ”‚ โ”‚
  88. // โ”‚ โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚
  89. // โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”‚ws_write โ”‚โ—€โ”€โ”ผโ”€โ”€โ”€โ”€โ”‚ ws_rx โ”‚โ—€โ”€โ”€โ”ผโ”€โ”€โ”‚ ws_tx โ”‚ โ”‚
  90. // โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚
  91. // โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
  92. let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
  93. let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
  94. let handlers = self.handlers.clone();
  95. self.sender = Some(Arc::new(WsSender::new(ws_tx)));
  96. self.addr = Some(addr.clone());
  97. (WsConnection::new(msg_tx, ws_rx, addr), WsHandlers::new(handlers, msg_rx))
  98. }
  99. pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
  100. match self.sender.as_ref() {
  101. None => Err(WsError::internal().context("Should call make_connect first")),
  102. Some(sender) => sender.send(msg.into()),
  103. }
  104. }
  105. }
  106. #[pin_project]
  107. pub struct WsHandlers {
  108. #[pin]
  109. msg_rx: MsgReceiver,
  110. handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
  111. }
  112. impl WsHandlers {
  113. fn new(handlers: HashMap<String, Arc<dyn WsMessageHandler>>, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } }
  114. }
  115. impl Future for WsHandlers {
  116. type Output = ();
  117. fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  118. loop {
  119. match ready!(self.as_mut().project().msg_rx.poll_next(cx)) {
  120. None => {
  121. return Poll::Ready(());
  122. },
  123. Some(message) => {
  124. log::debug!("๐Ÿด ws handler receive message");
  125. let message = WsMessage::from(message);
  126. match self.handlers.get(&message.source) {
  127. None => log::error!("Can't find any handler for message: {:?}", message),
  128. Some(handler) => handler.receive_message(message.clone()),
  129. }
  130. },
  131. }
  132. }
  133. }
  134. }
  135. struct WsSender {
  136. ws_tx: MsgSender,
  137. }
  138. impl WsSender {
  139. pub fn new(ws_tx: MsgSender) -> Self { Self { ws_tx } }
  140. pub fn send(&self, msg: WsMessage) -> Result<(), WsError> {
  141. let _ = self.ws_tx.unbounded_send(msg.into()).map_err(|e| WsError::internal().context(e))?;
  142. Ok(())
  143. }
  144. }
  #[cfg(test)]
  mod tests {
      use super::WsController;

      /// Builds a connection to `WS_ADDR` and runs the connection future and
      /// the handler dispatcher until either of them completes.
      #[tokio::test]
      async fn connect() {
          std::env::set_var("RUST_LOG", "Debug");
          // `try_init` rather than `init`: `init` panics when a global logger has
          // already been installed by another test in the same binary.
          let _ = env_logger::try_init();

          let mut controller = WsController::new();
          let addr = format!("{}/123", flowy_net::config::WS_ADDR.as_str());
          let (connection, handlers) = controller.make_connect(addr);
          tokio::select! {
              r = connection => println!("write completed {:?}", r),
              _ = handlers => println!("read completed"),
          };
      }
  }