ws.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. #![allow(clippy::type_complexity)]
  2. use crate::{
  3. connect::{WsConnectionFuture, WsStream},
  4. errors::WsError,
  5. WsMessage,
  6. WsModule,
  7. };
  8. use backend_service::errors::ServerError;
  9. use bytes::Bytes;
  10. use dashmap::DashMap;
  11. use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
  12. use futures_core::{ready, Stream};
  13. use lib_infra::retry::{Action, FixedInterval, Retry};
  14. use parking_lot::RwLock;
  15. use pin_project::pin_project;
  16. use std::{
  17. convert::TryFrom,
  18. fmt::Formatter,
  19. future::Future,
  20. pin::Pin,
  21. sync::Arc,
  22. task::{Context, Poll},
  23. time::Duration,
  24. };
  25. use tokio::sync::{broadcast, oneshot};
  26. use tokio_tungstenite::tungstenite::{
  27. protocol::{frame::coding::CloseCode, CloseFrame},
  28. Message,
  29. };
  30. pub type MsgReceiver = UnboundedReceiver<Message>;
  31. pub type MsgSender = UnboundedSender<Message>;
  32. type Handlers = DashMap<WsModule, Arc<dyn WsMessageHandler>>;
  33. pub trait WsMessageHandler: Sync + Send + 'static {
  34. fn source(&self) -> WsModule;
  35. fn receive_message(&self, msg: WsMessage);
  36. }
  37. pub struct WsController {
  38. handlers: Handlers,
  39. state_notify: Arc<broadcast::Sender<WsConnectState>>,
  40. sender_ctrl: Arc<RwLock<WsSenderController>>,
  41. addr: Arc<RwLock<Option<String>>>,
  42. }
  43. impl std::default::Default for WsController {
  44. fn default() -> Self {
  45. let (state_notify, _) = broadcast::channel(16);
  46. Self {
  47. handlers: DashMap::new(),
  48. sender_ctrl: Arc::new(RwLock::new(WsSenderController::default())),
  49. state_notify: Arc::new(state_notify),
  50. addr: Arc::new(RwLock::new(None)),
  51. }
  52. }
  53. }
  54. impl WsController {
  55. pub fn new() -> Self { WsController::default() }
  56. pub fn add_handler(&self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
  57. let source = handler.source();
  58. if self.handlers.contains_key(&source) {
  59. log::error!("WsSource's {:?} is already registered", source);
  60. }
  61. self.handlers.insert(source, handler);
  62. Ok(())
  63. }
  64. pub async fn start(&self, addr: String) -> Result<(), ServerError> {
  65. *self.addr.write() = Some(addr.clone());
  66. let strategy = FixedInterval::from_millis(5000).take(3);
  67. self.connect(addr, strategy).await
  68. }
  69. async fn connect<T, I>(&self, addr: String, strategy: T) -> Result<(), ServerError>
  70. where
  71. T: IntoIterator<IntoIter = I, Item = Duration>,
  72. I: Iterator<Item = Duration> + Send + 'static,
  73. {
  74. let (ret, rx) = oneshot::channel::<Result<(), ServerError>>();
  75. *self.addr.write() = Some(addr.clone());
  76. let action = WsConnectAction {
  77. addr,
  78. handlers: self.handlers.clone(),
  79. };
  80. let retry = Retry::spawn(strategy, action);
  81. let sender_ctrl = self.sender_ctrl.clone();
  82. sender_ctrl.write().set_state(WsConnectState::Connecting);
  83. tokio::spawn(async move {
  84. match retry.await {
  85. Ok(result) => {
  86. let WsConnectResult {
  87. stream,
  88. handlers_fut,
  89. sender,
  90. } = result;
  91. sender_ctrl.write().set_sender(sender);
  92. sender_ctrl.write().set_state(WsConnectState::Connected);
  93. let _ = ret.send(Ok(()));
  94. spawn_stream_and_handlers(stream, handlers_fut, sender_ctrl.clone()).await;
  95. },
  96. Err(e) => {
  97. sender_ctrl.write().set_error(e.clone());
  98. let _ = ret.send(Err(ServerError::internal().context(e)));
  99. },
  100. }
  101. });
  102. rx.await?
  103. }
  104. pub async fn retry(&self, count: usize) -> Result<(), ServerError> {
  105. if self.sender_ctrl.read().is_connecting() {
  106. return Ok(());
  107. }
  108. let strategy = FixedInterval::from_millis(5000).take(count);
  109. let addr = self
  110. .addr
  111. .read()
  112. .as_ref()
  113. .expect("must call start_connect first")
  114. .clone();
  115. self.connect(addr, strategy).await
  116. }
  117. pub fn state_subscribe(&self) -> broadcast::Receiver<WsConnectState> { self.state_notify.subscribe() }
  118. pub fn sender(&self) -> Result<Arc<WsSender>, WsError> {
  119. match self.sender_ctrl.read().sender() {
  120. None => Err(WsError::internal().context("WsSender is not initialized, should call connect first")),
  121. Some(sender) => Ok(sender),
  122. }
  123. }
  124. }
  125. async fn spawn_stream_and_handlers(
  126. stream: WsStream,
  127. handlers: WsHandlerFuture,
  128. sender_ctrl: Arc<RwLock<WsSenderController>>,
  129. ) {
  130. tokio::select! {
  131. result = stream => {
  132. if let Err(e) = result {
  133. sender_ctrl.write().set_error(e);
  134. }
  135. },
  136. result = handlers => tracing::debug!("handlers completed {:?}", result),
  137. };
  138. }
  139. #[pin_project]
  140. pub struct WsHandlerFuture {
  141. #[pin]
  142. msg_rx: MsgReceiver,
  143. // Opti: Hashmap would be better
  144. handlers: Handlers,
  145. }
  146. impl WsHandlerFuture {
  147. fn new(handlers: Handlers, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } }
  148. fn handler_ws_message(&self, message: Message) {
  149. if let Message::Binary(bytes) = message {
  150. self.handle_binary_message(bytes)
  151. }
  152. }
  153. fn handle_binary_message(&self, bytes: Vec<u8>) {
  154. let bytes = Bytes::from(bytes);
  155. match WsMessage::try_from(bytes) {
  156. Ok(message) => match self.handlers.get(&message.module) {
  157. None => log::error!("Can't find any handler for message: {:?}", message),
  158. Some(handler) => handler.receive_message(message.clone()),
  159. },
  160. Err(e) => {
  161. log::error!("Deserialize binary ws message failed: {:?}", e);
  162. },
  163. }
  164. }
  165. }
  166. impl Future for WsHandlerFuture {
  167. type Output = ();
  168. fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  169. loop {
  170. match ready!(self.as_mut().project().msg_rx.poll_next(cx)) {
  171. None => {
  172. return Poll::Ready(());
  173. },
  174. Some(message) => self.handler_ws_message(message),
  175. }
  176. }
  177. }
  178. }
  179. #[derive(Debug, Clone)]
  180. pub struct WsSender {
  181. ws_tx: MsgSender,
  182. }
  183. impl WsSender {
  184. pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
  185. let msg = msg.into();
  186. let _ = self
  187. .ws_tx
  188. .unbounded_send(msg.into())
  189. .map_err(|e| WsError::internal().context(e))?;
  190. Ok(())
  191. }
  192. pub fn send_text(&self, source: &WsModule, text: &str) -> Result<(), WsError> {
  193. let msg = WsMessage {
  194. module: source.clone(),
  195. data: text.as_bytes().to_vec(),
  196. };
  197. self.send_msg(msg)
  198. }
  199. pub fn send_binary(&self, source: &WsModule, bytes: Vec<u8>) -> Result<(), WsError> {
  200. let msg = WsMessage {
  201. module: source.clone(),
  202. data: bytes,
  203. };
  204. self.send_msg(msg)
  205. }
  206. pub fn send_disconnect(&self, reason: &str) -> Result<(), WsError> {
  207. let frame = CloseFrame {
  208. code: CloseCode::Normal,
  209. reason: reason.to_owned().into(),
  210. };
  211. let msg = Message::Close(Some(frame));
  212. let _ = self
  213. .ws_tx
  214. .unbounded_send(msg)
  215. .map_err(|e| WsError::internal().context(e))?;
  216. Ok(())
  217. }
  218. }
  219. struct WsConnectAction {
  220. addr: String,
  221. handlers: Handlers,
  222. }
  223. impl Action for WsConnectAction {
  224. type Future = Pin<Box<dyn Future<Output = Result<Self::Item, Self::Error>> + Send + Sync>>;
  225. type Item = WsConnectResult;
  226. type Error = WsError;
  227. fn run(&mut self) -> Self::Future {
  228. let addr = self.addr.clone();
  229. let handlers = self.handlers.clone();
  230. Box::pin(WsConnectActionFut::new(addr, handlers))
  231. }
  232. }
  233. struct WsConnectResult {
  234. stream: WsStream,
  235. handlers_fut: WsHandlerFuture,
  236. sender: WsSender,
  237. }
  238. #[pin_project]
  239. struct WsConnectActionFut {
  240. addr: String,
  241. #[pin]
  242. conn: WsConnectionFuture,
  243. handlers_fut: Option<WsHandlerFuture>,
  244. sender: Option<WsSender>,
  245. }
  246. impl WsConnectActionFut {
  247. fn new(addr: String, handlers: Handlers) -> Self {
  248. // Stream User
  249. // ┌───────────────┐ ┌──────────────┐
  250. // ┌──────┐ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
  251. // │Server│──────┼─▶│ ws_read │──┼───▶│ msg_tx │───┼─▶│ msg_rx │ │
  252. // └──────┘ │ └─────────┘ │ └────────┘ │ └────────┘ │
  253. // ▲ │ │ │ │
  254. // │ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
  255. // └─────────┼──│ws_write │◀─┼────│ ws_rx │◀──┼──│ ws_tx │ │
  256. // │ └─────────┘ │ └────────┘ │ └────────┘ │
  257. // └───────────────┘ └──────────────┘
  258. let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
  259. let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
  260. let sender = WsSender { ws_tx };
  261. let handlers_fut = WsHandlerFuture::new(handlers, msg_rx);
  262. let conn = WsConnectionFuture::new(msg_tx, ws_rx, addr.clone());
  263. Self {
  264. addr,
  265. conn,
  266. handlers_fut: Some(handlers_fut),
  267. sender: Some(sender),
  268. }
  269. }
  270. }
  271. impl Future for WsConnectActionFut {
  272. type Output = Result<WsConnectResult, WsError>;
  273. fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  274. let mut this = self.project();
  275. match ready!(this.conn.as_mut().poll(cx)) {
  276. Ok(stream) => {
  277. let handlers_fut = this.handlers_fut.take().expect("Only take once");
  278. let sender = this.sender.take().expect("Only take once");
  279. Poll::Ready(Ok(WsConnectResult {
  280. stream,
  281. handlers_fut,
  282. sender,
  283. }))
  284. },
  285. Err(e) => Poll::Ready(Err(e)),
  286. }
  287. }
  288. }
  289. #[derive(Clone, Eq, PartialEq)]
  290. pub enum WsConnectState {
  291. Init,
  292. Connecting,
  293. Connected,
  294. Disconnected,
  295. }
  296. impl std::fmt::Display for WsConnectState {
  297. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  298. match self {
  299. WsConnectState::Init => f.write_str("Init"),
  300. WsConnectState::Connected => f.write_str("Connecting"),
  301. WsConnectState::Connecting => f.write_str("Connected"),
  302. WsConnectState::Disconnected => f.write_str("Disconnected"),
  303. }
  304. }
  305. }
  306. impl std::fmt::Debug for WsConnectState {
  307. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_str(&format!("{}", self)) }
  308. }
  309. struct WsSenderController {
  310. state: WsConnectState,
  311. state_notify: Arc<broadcast::Sender<WsConnectState>>,
  312. sender: Option<Arc<WsSender>>,
  313. }
  314. impl WsSenderController {
  315. fn set_sender(&mut self, sender: WsSender) { self.sender = Some(Arc::new(sender)); }
  316. fn set_state(&mut self, state: WsConnectState) {
  317. if state != WsConnectState::Connected {
  318. self.sender = None;
  319. }
  320. self.state = state.clone();
  321. let _ = self.state_notify.send(state);
  322. }
  323. fn set_error(&mut self, error: WsError) {
  324. log::error!("{:?}", error);
  325. self.set_state(WsConnectState::Disconnected);
  326. }
  327. fn sender(&self) -> Option<Arc<WsSender>> { self.sender.clone() }
  328. fn is_connecting(&self) -> bool { self.state == WsConnectState::Connecting }
  329. #[allow(dead_code)]
  330. fn is_connected(&self) -> bool { self.state == WsConnectState::Connected }
  331. }
  332. impl std::default::Default for WsSenderController {
  333. fn default() -> Self {
  334. let (state_notify, _) = broadcast::channel(16);
  335. WsSenderController {
  336. state: WsConnectState::Init,
  337. state_notify: Arc::new(state_notify),
  338. sender: None,
  339. }
  340. }
  341. }