// ws.rs
  1. use crate::{
  2. connect::{WsConnectionFuture, WsStream},
  3. errors::WsError,
  4. WsMessage,
  5. WsModule,
  6. };
  7. use backend_service::errors::ServerError;
  8. use bytes::Bytes;
  9. use dashmap::DashMap;
  10. use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
  11. use futures_core::{ready, Stream};
  12. use lib_infra::retry::{Action, FixedInterval, Retry};
  13. use parking_lot::RwLock;
  14. use pin_project::pin_project;
  15. use std::{
  16. convert::TryFrom,
  17. fmt::Formatter,
  18. future::Future,
  19. pin::Pin,
  20. sync::Arc,
  21. task::{Context, Poll},
  22. time::Duration,
  23. };
  24. use tokio::sync::{broadcast, oneshot};
  25. use tokio_tungstenite::tungstenite::{
  26. protocol::{frame::coding::CloseCode, CloseFrame},
  27. Message,
  28. };
  29. pub type MsgReceiver = UnboundedReceiver<Message>;
  30. pub type MsgSender = UnboundedSender<Message>;
  31. type Handlers = DashMap<WsModule, Arc<dyn WsMessageHandler>>;
  32. pub trait WsMessageHandler: Sync + Send + 'static {
  33. fn source(&self) -> WsModule;
  34. fn receive_message(&self, msg: WsMessage);
  35. }
  36. #[derive(Clone)]
  37. pub enum WsState {
  38. Init,
  39. Connected(Arc<WsSender>),
  40. Disconnected(WsError),
  41. }
  42. impl std::fmt::Display for WsState {
  43. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  44. match self {
  45. WsState::Init => f.write_str("Init"),
  46. WsState::Connected(_) => f.write_str("Connected"),
  47. WsState::Disconnected(_) => f.write_str("Disconnected"),
  48. }
  49. }
  50. }
  51. impl std::fmt::Debug for WsState {
  52. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_str(&format!("{}", self)) }
  53. }
  54. pub struct WsController {
  55. handlers: Handlers,
  56. state_notify: Arc<broadcast::Sender<WsState>>,
  57. sender: Arc<RwLock<Option<Arc<WsSender>>>>,
  58. addr: Arc<RwLock<Option<String>>>,
  59. }
  60. impl WsController {
  61. pub fn new() -> Self {
  62. let (state_notify, _) = broadcast::channel(16);
  63. let controller = Self {
  64. handlers: DashMap::new(),
  65. sender: Arc::new(RwLock::new(None)),
  66. state_notify: Arc::new(state_notify),
  67. addr: Arc::new(RwLock::new(None)),
  68. };
  69. controller
  70. }
  71. pub fn add_handler(&self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
  72. let source = handler.source();
  73. if self.handlers.contains_key(&source) {
  74. log::error!("WsSource's {:?} is already registered", source);
  75. }
  76. self.handlers.insert(source, handler);
  77. Ok(())
  78. }
  79. pub async fn start_connect(&self, addr: String) -> Result<(), ServerError> {
  80. *self.addr.write() = Some(addr.clone());
  81. let strategy = FixedInterval::from_millis(5000).take(3);
  82. self.connect(addr, strategy).await
  83. }
  84. async fn connect<T, I>(&self, addr: String, strategy: T) -> Result<(), ServerError>
  85. where
  86. T: IntoIterator<IntoIter = I, Item = Duration>,
  87. I: Iterator<Item = Duration> + Send + 'static,
  88. {
  89. let (ret, rx) = oneshot::channel::<Result<(), ServerError>>();
  90. *self.addr.write() = Some(addr.clone());
  91. let action = WsConnectAction {
  92. addr,
  93. handlers: self.handlers.clone(),
  94. };
  95. let retry = Retry::spawn(strategy, action);
  96. let sender_holder = self.sender.clone();
  97. let state_notify = self.state_notify.clone();
  98. tokio::spawn(async move {
  99. match retry.await {
  100. Ok(result) => {
  101. let WsConnectResult {
  102. stream,
  103. handlers_fut,
  104. sender,
  105. } = result;
  106. let sender = Arc::new(sender);
  107. *sender_holder.write() = Some(sender.clone());
  108. let _ = state_notify.send(WsState::Connected(sender));
  109. let _ = ret.send(Ok(()));
  110. spawn_stream_and_handlers(stream, handlers_fut, state_notify).await;
  111. },
  112. Err(e) => {
  113. //
  114. let _ = state_notify.send(WsState::Disconnected(e.clone()));
  115. let _ = ret.send(Err(ServerError::internal().context(e)));
  116. },
  117. }
  118. });
  119. rx.await?
  120. }
  121. pub async fn retry(&self) -> Result<(), ServerError> {
  122. let addr = self
  123. .addr
  124. .read()
  125. .as_ref()
  126. .expect("must call start_connect first")
  127. .clone();
  128. let strategy = FixedInterval::from_millis(5000);
  129. self.connect(addr, strategy).await
  130. }
  131. pub fn state_subscribe(&self) -> broadcast::Receiver<WsState> { self.state_notify.subscribe() }
  132. pub fn sender(&self) -> Result<Arc<WsSender>, WsError> {
  133. match &*self.sender.read() {
  134. None => Err(WsError::internal().context("WsSender is not initialized, should call connect first")),
  135. Some(sender) => Ok(sender.clone()),
  136. }
  137. }
  138. }
  139. async fn spawn_stream_and_handlers(
  140. stream: WsStream,
  141. handlers: WsHandlerFuture,
  142. state_notify: Arc<broadcast::Sender<WsState>>,
  143. ) {
  144. tokio::select! {
  145. result = stream => {
  146. match result {
  147. Ok(_) => {},
  148. Err(e) => {
  149. log::error!("websocket error: {:?}", e);
  150. let _ = state_notify.send(WsState::Disconnected(e)).unwrap();
  151. }
  152. }
  153. },
  154. result = handlers => tracing::debug!("handlers completed {:?}", result),
  155. };
  156. }
  157. #[pin_project]
  158. pub struct WsHandlerFuture {
  159. #[pin]
  160. msg_rx: MsgReceiver,
  161. // Opti: Hashmap would be better
  162. handlers: Handlers,
  163. }
  164. impl WsHandlerFuture {
  165. fn new(handlers: Handlers, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } }
  166. fn handler_ws_message(&self, message: Message) {
  167. match message {
  168. Message::Binary(bytes) => self.handle_binary_message(bytes),
  169. _ => {},
  170. }
  171. }
  172. fn handle_binary_message(&self, bytes: Vec<u8>) {
  173. let bytes = Bytes::from(bytes);
  174. match WsMessage::try_from(bytes) {
  175. Ok(message) => match self.handlers.get(&message.module) {
  176. None => log::error!("Can't find any handler for message: {:?}", message),
  177. Some(handler) => handler.receive_message(message.clone()),
  178. },
  179. Err(e) => {
  180. log::error!("Deserialize binary ws message failed: {:?}", e);
  181. },
  182. }
  183. }
  184. }
  185. impl Future for WsHandlerFuture {
  186. type Output = ();
  187. fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  188. loop {
  189. match ready!(self.as_mut().project().msg_rx.poll_next(cx)) {
  190. None => {
  191. return Poll::Ready(());
  192. },
  193. Some(message) => self.handler_ws_message(message),
  194. }
  195. }
  196. }
  197. }
  198. #[derive(Debug, Clone)]
  199. pub struct WsSender {
  200. ws_tx: MsgSender,
  201. }
  202. impl WsSender {
  203. pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
  204. let msg = msg.into();
  205. let _ = self
  206. .ws_tx
  207. .unbounded_send(msg.into())
  208. .map_err(|e| WsError::internal().context(e))?;
  209. Ok(())
  210. }
  211. pub fn send_text(&self, source: &WsModule, text: &str) -> Result<(), WsError> {
  212. let msg = WsMessage {
  213. module: source.clone(),
  214. data: text.as_bytes().to_vec(),
  215. };
  216. self.send_msg(msg)
  217. }
  218. pub fn send_binary(&self, source: &WsModule, bytes: Vec<u8>) -> Result<(), WsError> {
  219. let msg = WsMessage {
  220. module: source.clone(),
  221. data: bytes,
  222. };
  223. self.send_msg(msg)
  224. }
  225. pub fn send_disconnect(&self, reason: &str) -> Result<(), WsError> {
  226. let frame = CloseFrame {
  227. code: CloseCode::Normal,
  228. reason: reason.to_owned().into(),
  229. };
  230. let msg = Message::Close(Some(frame));
  231. let _ = self
  232. .ws_tx
  233. .unbounded_send(msg)
  234. .map_err(|e| WsError::internal().context(e))?;
  235. Ok(())
  236. }
  237. }
  238. struct WsConnectAction {
  239. addr: String,
  240. handlers: Handlers,
  241. }
  242. struct WsConnectResult {
  243. stream: WsStream,
  244. handlers_fut: WsHandlerFuture,
  245. sender: WsSender,
  246. }
  247. #[pin_project]
  248. struct WsConnectActionFut {
  249. addr: String,
  250. #[pin]
  251. conn: WsConnectionFuture,
  252. handlers_fut: Option<WsHandlerFuture>,
  253. sender: Option<WsSender>,
  254. }
  255. impl WsConnectActionFut {
  256. fn new(addr: String, handlers: Handlers) -> Self {
  257. // Stream User
  258. // ┌───────────────┐ ┌──────────────┐
  259. // ┌──────┐ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
  260. // │Server│──────┼─▶│ ws_read │──┼───▶│ msg_tx │───┼─▶│ msg_rx │ │
  261. // └──────┘ │ └─────────┘ │ └────────┘ │ └────────┘ │
  262. // ▲ │ │ │ │
  263. // │ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
  264. // └─────────┼──│ws_write │◀─┼────│ ws_rx │◀──┼──│ ws_tx │ │
  265. // │ └─────────┘ │ └────────┘ │ └────────┘ │
  266. // └───────────────┘ └──────────────┘
  267. let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
  268. let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
  269. let sender = WsSender { ws_tx };
  270. let handlers_fut = WsHandlerFuture::new(handlers, msg_rx);
  271. let conn = WsConnectionFuture::new(msg_tx, ws_rx, addr.clone());
  272. Self {
  273. addr,
  274. conn,
  275. handlers_fut: Some(handlers_fut),
  276. sender: Some(sender),
  277. }
  278. }
  279. }
  280. impl Future for WsConnectActionFut {
  281. type Output = Result<WsConnectResult, WsError>;
  282. fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  283. let mut this = self.project();
  284. match ready!(this.conn.as_mut().poll(cx)) {
  285. Ok(stream) => {
  286. let handlers_fut = this.handlers_fut.take().expect("Only take once");
  287. let sender = this.sender.take().expect("Only take once");
  288. Poll::Ready(Ok(WsConnectResult {
  289. stream,
  290. handlers_fut,
  291. sender,
  292. }))
  293. },
  294. Err(e) => Poll::Ready(Err(e)),
  295. }
  296. }
  297. }
  298. impl Action for WsConnectAction {
  299. type Future = Pin<Box<dyn Future<Output = Result<Self::Item, Self::Error>> + Send + Sync>>;
  300. type Item = WsConnectResult;
  301. type Error = WsError;
  302. fn run(&mut self) -> Self::Future {
  303. let addr = self.addr.clone();
  304. let handlers = self.handlers.clone();
  305. Box::pin(WsConnectActionFut::new(addr, handlers))
  306. }
  307. }