ws.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. #![allow(clippy::type_complexity)]
  2. use crate::{
  3. connect::{WsConnectionFuture, WsStream},
  4. errors::WsError,
  5. WsMessage,
  6. WsModule,
  7. };
  8. use backend_service::errors::ServerError;
  9. use bytes::Bytes;
  10. use dashmap::DashMap;
  11. use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
  12. use futures_core::{ready, Stream};
  13. use lib_infra::retry::{Action, FixedInterval, Retry};
  14. use parking_lot::RwLock;
  15. use pin_project::pin_project;
  16. use std::{
  17. convert::TryFrom,
  18. fmt::Formatter,
  19. future::Future,
  20. pin::Pin,
  21. sync::Arc,
  22. task::{Context, Poll},
  23. time::Duration,
  24. };
  25. use tokio::sync::{broadcast, oneshot};
  26. use tokio_tungstenite::tungstenite::{
  27. protocol::{frame::coding::CloseCode, CloseFrame},
  28. Message,
  29. };
  30. pub type MsgReceiver = UnboundedReceiver<Message>;
  31. pub type MsgSender = UnboundedSender<Message>;
  32. type Handlers = DashMap<WsModule, Arc<dyn WsMessageHandler>>;
  33. pub trait WsMessageHandler: Sync + Send + 'static {
  34. fn source(&self) -> WsModule;
  35. fn receive_message(&self, msg: WsMessage);
  36. }
  37. #[derive(Clone)]
  38. pub enum WsState {
  39. Init,
  40. Connected(Arc<WsSender>),
  41. Disconnected(WsError),
  42. }
  43. impl std::fmt::Display for WsState {
  44. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  45. match self {
  46. WsState::Init => f.write_str("Init"),
  47. WsState::Connected(_) => f.write_str("Connected"),
  48. WsState::Disconnected(_) => f.write_str("Disconnected"),
  49. }
  50. }
  51. }
  52. impl std::fmt::Debug for WsState {
  53. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_str(&format!("{}", self)) }
  54. }
  55. pub struct WsController {
  56. handlers: Handlers,
  57. state_notify: Arc<broadcast::Sender<WsState>>,
  58. sender: Arc<RwLock<Option<Arc<WsSender>>>>,
  59. addr: Arc<RwLock<Option<String>>>,
  60. }
  61. impl std::default::Default for WsController {
  62. fn default() -> Self {
  63. let (state_notify, _) = broadcast::channel(16);
  64. Self {
  65. handlers: DashMap::new(),
  66. sender: Arc::new(RwLock::new(None)),
  67. state_notify: Arc::new(state_notify),
  68. addr: Arc::new(RwLock::new(None)),
  69. }
  70. }
  71. }
  72. impl WsController {
  73. pub fn new() -> Self { WsController::default() }
  74. pub fn add_handler(&self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
  75. let source = handler.source();
  76. if self.handlers.contains_key(&source) {
  77. log::error!("WsSource's {:?} is already registered", source);
  78. }
  79. self.handlers.insert(source, handler);
  80. Ok(())
  81. }
  82. pub async fn start_connect(&self, addr: String) -> Result<(), ServerError> {
  83. *self.addr.write() = Some(addr.clone());
  84. let strategy = FixedInterval::from_millis(5000).take(3);
  85. self.connect(addr, strategy).await
  86. }
  87. async fn connect<T, I>(&self, addr: String, strategy: T) -> Result<(), ServerError>
  88. where
  89. T: IntoIterator<IntoIter = I, Item = Duration>,
  90. I: Iterator<Item = Duration> + Send + 'static,
  91. {
  92. let (ret, rx) = oneshot::channel::<Result<(), ServerError>>();
  93. *self.addr.write() = Some(addr.clone());
  94. let action = WsConnectAction {
  95. addr,
  96. handlers: self.handlers.clone(),
  97. };
  98. let retry = Retry::spawn(strategy, action);
  99. let sender_holder = self.sender.clone();
  100. let state_notify = self.state_notify.clone();
  101. tokio::spawn(async move {
  102. match retry.await {
  103. Ok(result) => {
  104. let WsConnectResult {
  105. stream,
  106. handlers_fut,
  107. sender,
  108. } = result;
  109. let sender = Arc::new(sender);
  110. *sender_holder.write() = Some(sender.clone());
  111. let _ = state_notify.send(WsState::Connected(sender));
  112. let _ = ret.send(Ok(()));
  113. spawn_stream_and_handlers(stream, handlers_fut, state_notify).await;
  114. },
  115. Err(e) => {
  116. //
  117. let _ = state_notify.send(WsState::Disconnected(e.clone()));
  118. let _ = ret.send(Err(ServerError::internal().context(e)));
  119. },
  120. }
  121. });
  122. rx.await?
  123. }
  124. pub async fn retry(&self) -> Result<(), ServerError> {
  125. let addr = self
  126. .addr
  127. .read()
  128. .as_ref()
  129. .expect("must call start_connect first")
  130. .clone();
  131. let strategy = FixedInterval::from_millis(5000);
  132. self.connect(addr, strategy).await
  133. }
  134. pub fn state_subscribe(&self) -> broadcast::Receiver<WsState> { self.state_notify.subscribe() }
  135. pub fn sender(&self) -> Result<Arc<WsSender>, WsError> {
  136. match &*self.sender.read() {
  137. None => Err(WsError::internal().context("WsSender is not initialized, should call connect first")),
  138. Some(sender) => Ok(sender.clone()),
  139. }
  140. }
  141. }
  142. async fn spawn_stream_and_handlers(
  143. stream: WsStream,
  144. handlers: WsHandlerFuture,
  145. state_notify: Arc<broadcast::Sender<WsState>>,
  146. ) {
  147. tokio::select! {
  148. result = stream => {
  149. match result {
  150. Ok(_) => {},
  151. Err(e) => {
  152. log::error!("websocket error: {:?}", e);
  153. let _ = state_notify.send(WsState::Disconnected(e)).unwrap();
  154. }
  155. }
  156. },
  157. result = handlers => tracing::debug!("handlers completed {:?}", result),
  158. };
  159. }
  160. #[pin_project]
  161. pub struct WsHandlerFuture {
  162. #[pin]
  163. msg_rx: MsgReceiver,
  164. // Opti: Hashmap would be better
  165. handlers: Handlers,
  166. }
  167. impl WsHandlerFuture {
  168. fn new(handlers: Handlers, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } }
  169. fn handler_ws_message(&self, message: Message) {
  170. if let Message::Binary(bytes) = message {
  171. self.handle_binary_message(bytes)
  172. }
  173. }
  174. fn handle_binary_message(&self, bytes: Vec<u8>) {
  175. let bytes = Bytes::from(bytes);
  176. match WsMessage::try_from(bytes) {
  177. Ok(message) => match self.handlers.get(&message.module) {
  178. None => log::error!("Can't find any handler for message: {:?}", message),
  179. Some(handler) => handler.receive_message(message.clone()),
  180. },
  181. Err(e) => {
  182. log::error!("Deserialize binary ws message failed: {:?}", e);
  183. },
  184. }
  185. }
  186. }
  187. impl Future for WsHandlerFuture {
  188. type Output = ();
  189. fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  190. loop {
  191. match ready!(self.as_mut().project().msg_rx.poll_next(cx)) {
  192. None => {
  193. return Poll::Ready(());
  194. },
  195. Some(message) => self.handler_ws_message(message),
  196. }
  197. }
  198. }
  199. }
  200. #[derive(Debug, Clone)]
  201. pub struct WsSender {
  202. ws_tx: MsgSender,
  203. }
  204. impl WsSender {
  205. pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
  206. let msg = msg.into();
  207. let _ = self
  208. .ws_tx
  209. .unbounded_send(msg.into())
  210. .map_err(|e| WsError::internal().context(e))?;
  211. Ok(())
  212. }
  213. pub fn send_text(&self, source: &WsModule, text: &str) -> Result<(), WsError> {
  214. let msg = WsMessage {
  215. module: source.clone(),
  216. data: text.as_bytes().to_vec(),
  217. };
  218. self.send_msg(msg)
  219. }
  220. pub fn send_binary(&self, source: &WsModule, bytes: Vec<u8>) -> Result<(), WsError> {
  221. let msg = WsMessage {
  222. module: source.clone(),
  223. data: bytes,
  224. };
  225. self.send_msg(msg)
  226. }
  227. pub fn send_disconnect(&self, reason: &str) -> Result<(), WsError> {
  228. let frame = CloseFrame {
  229. code: CloseCode::Normal,
  230. reason: reason.to_owned().into(),
  231. };
  232. let msg = Message::Close(Some(frame));
  233. let _ = self
  234. .ws_tx
  235. .unbounded_send(msg)
  236. .map_err(|e| WsError::internal().context(e))?;
  237. Ok(())
  238. }
  239. }
  240. struct WsConnectAction {
  241. addr: String,
  242. handlers: Handlers,
  243. }
  244. struct WsConnectResult {
  245. stream: WsStream,
  246. handlers_fut: WsHandlerFuture,
  247. sender: WsSender,
  248. }
  249. #[pin_project]
  250. struct WsConnectActionFut {
  251. addr: String,
  252. #[pin]
  253. conn: WsConnectionFuture,
  254. handlers_fut: Option<WsHandlerFuture>,
  255. sender: Option<WsSender>,
  256. }
  257. impl WsConnectActionFut {
  258. fn new(addr: String, handlers: Handlers) -> Self {
  259. // Stream User
  260. // ┌───────────────┐ ┌──────────────┐
  261. // ┌──────┐ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
  262. // │Server│──────┼─▶│ ws_read │──┼───▶│ msg_tx │───┼─▶│ msg_rx │ │
  263. // └──────┘ │ └─────────┘ │ └────────┘ │ └────────┘ │
  264. // ▲ │ │ │ │
  265. // │ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
  266. // └─────────┼──│ws_write │◀─┼────│ ws_rx │◀──┼──│ ws_tx │ │
  267. // │ └─────────┘ │ └────────┘ │ └────────┘ │
  268. // └───────────────┘ └──────────────┘
  269. let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
  270. let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
  271. let sender = WsSender { ws_tx };
  272. let handlers_fut = WsHandlerFuture::new(handlers, msg_rx);
  273. let conn = WsConnectionFuture::new(msg_tx, ws_rx, addr.clone());
  274. Self {
  275. addr,
  276. conn,
  277. handlers_fut: Some(handlers_fut),
  278. sender: Some(sender),
  279. }
  280. }
  281. }
  282. impl Future for WsConnectActionFut {
  283. type Output = Result<WsConnectResult, WsError>;
  284. fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  285. let mut this = self.project();
  286. match ready!(this.conn.as_mut().poll(cx)) {
  287. Ok(stream) => {
  288. let handlers_fut = this.handlers_fut.take().expect("Only take once");
  289. let sender = this.sender.take().expect("Only take once");
  290. Poll::Ready(Ok(WsConnectResult {
  291. stream,
  292. handlers_fut,
  293. sender,
  294. }))
  295. },
  296. Err(e) => Poll::Ready(Err(e)),
  297. }
  298. }
  299. }
  300. impl Action for WsConnectAction {
  301. // noinspection RsExternalLinter
  302. type Future = Pin<Box<dyn Future<Output = Result<Self::Item, Self::Error>> + Send + Sync>>;
  303. type Item = WsConnectResult;
  304. type Error = WsError;
  305. fn run(&mut self) -> Self::Future {
  306. let addr = self.addr.clone();
  307. let handlers = self.handlers.clone();
  308. Box::pin(WsConnectActionFut::new(addr, handlers))
  309. }
  310. }