ws.rs 9.2 KB


  1. use crate::{connect::WsConnection, errors::WsError, WsMessage};
  2. use flowy_net::errors::ServerError;
  3. use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
  4. use futures_core::{ready, Stream};
  5. use crate::connect::Retry;
  6. use bytes::Buf;
  7. use futures_core::future::BoxFuture;
  8. use pin_project::pin_project;
  9. use std::{
  10. collections::HashMap,
  11. future::Future,
  12. marker::PhantomData,
  13. pin::Pin,
  14. sync::Arc,
  15. task::{Context, Poll},
  16. };
  17. use tokio::{sync::RwLock, task::JoinHandle};
  18. use tokio_tungstenite::{
  19. tungstenite::{
  20. protocol::{frame::coding::CloseCode, CloseFrame},
  21. Message,
  22. },
  23. MaybeTlsStream,
  24. WebSocketStream,
  25. };
  26. pub type MsgReceiver = UnboundedReceiver<Message>;
  27. pub type MsgSender = UnboundedSender<Message>;
  28. pub trait WsMessageHandler: Sync + Send + 'static {
  29. fn source(&self) -> String;
  30. fn receive_message(&self, msg: WsMessage);
  31. }
  32. type NotifyCallback = Arc<dyn Fn(&WsState) + Send + Sync + 'static>;
  33. struct WsStateNotify {
  34. #[allow(dead_code)]
  35. state: WsState,
  36. callback: Option<NotifyCallback>,
  37. }
  38. impl WsStateNotify {
  39. fn update_state(&mut self, state: WsState) {
  40. if let Some(f) = &self.callback {
  41. f(&state);
  42. }
  43. self.state = state;
  44. }
  45. }
  46. pub enum WsState {
  47. Init,
  48. Connected(Arc<WsSender>),
  49. Disconnected(WsError),
  50. }
  51. pub struct WsController {
  52. handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
  53. state_notify: Arc<RwLock<WsStateNotify>>,
  54. addr: Option<String>,
  55. sender: Option<Arc<WsSender>>,
  56. }
  57. impl WsController {
  58. pub fn new() -> Self {
  59. let state_notify = Arc::new(RwLock::new(WsStateNotify {
  60. state: WsState::Init,
  61. callback: None,
  62. }));
  63. let controller = Self {
  64. handlers: HashMap::new(),
  65. state_notify,
  66. addr: None,
  67. sender: None,
  68. };
  69. controller
  70. }
  71. pub async fn state_callback<SC>(&self, callback: SC)
  72. where
  73. SC: Fn(&WsState) + Send + Sync + 'static,
  74. {
  75. (self.state_notify.write().await).callback = Some(Arc::new(callback));
  76. }
  77. pub fn add_handler(&mut self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
  78. let source = handler.source();
  79. if self.handlers.contains_key(&source) {
  80. return Err(WsError::duplicate_source());
  81. }
  82. self.handlers.insert(source, handler);
  83. Ok(())
  84. }
  85. pub fn connect(&mut self, addr: String) -> Result<JoinHandle<()>, ServerError> { self._connect(addr.clone(), None) }
  86. pub fn connect_with_retry<F>(&mut self, addr: String, retry: Retry<F>) -> Result<JoinHandle<()>, ServerError>
  87. where
  88. F: Fn(&str) + Send + Sync + 'static,
  89. {
  90. self._connect(addr, Some(Box::pin(async { retry.await })))
  91. }
  92. fn _connect(&mut self, addr: String, retry: Option<BoxFuture<'static, ()>>) -> Result<JoinHandle<()>, ServerError> {
  93. log::debug!("🐴 ws connect: {}", &addr);
  94. let (connection, handlers) = self.make_connect(addr.clone());
  95. let state_notify = self.state_notify.clone();
  96. let sender = self.sender.clone().expect("Sender should be not empty after calling make_connect");
  97. Ok(tokio::spawn(async move {
  98. match connection.await {
  99. Ok(stream) => {
  100. state_notify.write().await.update_state(WsState::Connected(sender));
  101. tokio::select! {
  102. result = stream => {
  103. match result {
  104. Ok(_) => {},
  105. Err(e) => {
  106. // TODO: retry?
  107. log::error!("ws stream error {:?}", e);
  108. state_notify.write().await.update_state(WsState::Disconnected(e));
  109. }
  110. }
  111. },
  112. result = handlers => log::debug!("handlers completed {:?}", result),
  113. };
  114. },
  115. Err(e) => {
  116. log::error!("ws connect {} failed {:?}", addr, e);
  117. state_notify.write().await.update_state(WsState::Disconnected(e));
  118. if let Some(retry) = retry {
  119. tokio::spawn(retry);
  120. }
  121. },
  122. }
  123. }))
  124. }
  125. fn make_connect(&mut self, addr: String) -> (WsConnection, WsHandlers) {
  126. // Stream User
  127. // ┌───────────────┐ ┌──────────────┐
  128. // ┌──────┐ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
  129. // │Server│──────┼─▶│ ws_read │──┼───▶│ msg_tx │───┼─▶│ msg_rx │ │
  130. // └──────┘ │ └─────────┘ │ └────────┘ │ └────────┘ │
  131. // ▲ │ │ │ │
  132. // │ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
  133. // └─────────┼──│ws_write │◀─┼────│ ws_rx │◀──┼──│ ws_tx │ │
  134. // │ └─────────┘ │ └────────┘ │ └────────┘ │
  135. // └───────────────┘ └──────────────┘
  136. let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
  137. let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
  138. let handlers = self.handlers.clone();
  139. self.sender = Some(Arc::new(WsSender { ws_tx }));
  140. self.addr = Some(addr.clone());
  141. (WsConnection::new(msg_tx, ws_rx, addr), WsHandlers::new(handlers, msg_rx))
  142. }
  143. }
  144. #[pin_project]
  145. pub struct WsHandlers {
  146. #[pin]
  147. msg_rx: MsgReceiver,
  148. handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
  149. }
  150. impl WsHandlers {
  151. fn new(handlers: HashMap<String, Arc<dyn WsMessageHandler>>, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } }
  152. }
  153. impl Future for WsHandlers {
  154. type Output = ();
  155. fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  156. loop {
  157. match ready!(self.as_mut().project().msg_rx.poll_next(cx)) {
  158. None => {
  159. return Poll::Ready(());
  160. },
  161. Some(message) => {
  162. log::debug!("🐴 ws handler receive message");
  163. let message = WsMessage::from(message);
  164. match self.handlers.get(&message.source) {
  165. None => log::error!("Can't find any handler for message: {:?}", message),
  166. Some(handler) => handler.receive_message(message.clone()),
  167. }
  168. },
  169. }
  170. }
  171. }
  172. }
  173. // impl WsSender for WsController {
  174. // fn send_msg(&self, msg: WsMessage) -> Result<(), WsError> {
  175. // match self.ws_tx.as_ref() {
  176. // None => Err(WsError::internal().context("Should call make_connect
  177. // first")), Some(sender) => {
  178. // let _ = sender.unbounded_send(msg.into()).map_err(|e|
  179. // WsError::internal().context(e))?; Ok(())
  180. // },
  181. // }
  182. // }
  183. // }
  184. #[derive(Debug, Clone)]
  185. pub struct WsSender {
  186. ws_tx: MsgSender,
  187. }
  188. impl WsSender {
  189. pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
  190. let msg = msg.into();
  191. let _ = self.ws_tx.unbounded_send(msg.into()).map_err(|e| WsError::internal().context(e))?;
  192. Ok(())
  193. }
  194. pub fn send_text(&self, source: &str, text: &str) -> Result<(), WsError> {
  195. let msg = WsMessage {
  196. source: source.to_string(),
  197. data: text.as_bytes().to_vec(),
  198. };
  199. self.send_msg(msg)
  200. }
  201. pub fn send_binary(&self, source: &str, bytes: Vec<u8>) -> Result<(), WsError> {
  202. let msg = WsMessage {
  203. source: source.to_string(),
  204. data: bytes,
  205. };
  206. self.send_msg(msg)
  207. }
  208. pub fn send_disconnect(&self, reason: &str) -> Result<(), WsError> {
  209. let frame = CloseFrame {
  210. code: CloseCode::Normal,
  211. reason: reason.to_owned().into(),
  212. };
  213. let msg = Message::Close(Some(frame));
  214. let _ = self.ws_tx.unbounded_send(msg).map_err(|e| WsError::internal().context(e))?;
  215. Ok(())
  216. }
  217. }
  #[cfg(test)]
  mod tests {
      use super::WsController;

      /// Smoke test: opens a connection to `{WS_ADDR}/123` and runs until either the
      /// connection future or the handler future completes. Presumably needs a reachable
      /// websocket server at that address — confirm before running in CI.
      #[tokio::test]
      async fn connect() {
          std::env::set_var("RUST_LOG", "Debug");
          // `init` panics if a logger is already installed (e.g. by another test in the
          // same test binary); `try_init` keeps this test order-independent.
          let _ = env_logger::try_init();
          let mut controller = WsController::new();
          let addr = format!("{}/123", flowy_net::config::WS_ADDR.as_str());
          let (connection, handlers) = controller.make_connect(addr);
          tokio::select! {
              r = connection => println!("write completed {:?}", r),
              _ = handlers => println!("read completed"),
          };
      }
  }