ws.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. #![allow(clippy::type_complexity)]
  2. use crate::{
  3. connect::{WSConnectionFuture, WSStream},
  4. errors::WSError,
  5. WSChannel, WebSocketRawMessage,
  6. };
  7. use dashmap::DashMap;
  8. use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
  9. use futures_core::{ready, Stream};
  10. use lib_infra::retry::{Action, FixedInterval, Retry};
  11. use pin_project::pin_project;
  12. use std::{
  13. fmt::Formatter,
  14. future::Future,
  15. pin::Pin,
  16. sync::Arc,
  17. task::{Context, Poll},
  18. time::Duration,
  19. };
  20. use tokio::sync::{broadcast, oneshot, RwLock};
  21. use tokio_tungstenite::tungstenite::{
  22. protocol::{frame::coding::CloseCode, CloseFrame},
  23. Message,
  24. };
  25. pub type MsgReceiver = UnboundedReceiver<Message>;
  26. pub type MsgSender = UnboundedSender<Message>;
  27. type Handlers = DashMap<WSChannel, Arc<dyn WSMessageReceiver>>;
  28. pub trait WSMessageReceiver: Sync + Send + 'static {
  29. fn source(&self) -> WSChannel;
  30. fn receive_message(&self, msg: WebSocketRawMessage);
  31. }
  32. pub struct WSController {
  33. handlers: Handlers,
  34. addr: Arc<RwLock<Option<String>>>,
  35. sender: Arc<RwLock<Option<Arc<WSSender>>>>,
  36. conn_state_notify: Arc<RwLock<WSConnectStateNotifier>>,
  37. }
  38. impl std::fmt::Display for WSController {
  39. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  40. f.write_str("WebSocket")
  41. }
  42. }
  43. impl std::default::Default for WSController {
  44. fn default() -> Self {
  45. Self {
  46. handlers: DashMap::new(),
  47. addr: Arc::new(RwLock::new(None)),
  48. sender: Arc::new(RwLock::new(None)),
  49. conn_state_notify: Arc::new(RwLock::new(WSConnectStateNotifier::default())),
  50. }
  51. }
  52. }
  53. impl WSController {
  54. pub fn new() -> Self {
  55. WSController::default()
  56. }
  57. pub fn add_ws_message_receiver(
  58. &self,
  59. handler: Arc<dyn WSMessageReceiver>,
  60. ) -> Result<(), WSError> {
  61. let source = handler.source();
  62. if self.handlers.contains_key(&source) {
  63. log::error!("{:?} is already registered", source);
  64. }
  65. self.handlers.insert(source, handler);
  66. Ok(())
  67. }
  68. pub async fn start(&self, addr: String) -> Result<(), WSError> {
  69. *self.addr.write().await = Some(addr.clone());
  70. let strategy = FixedInterval::from_millis(5000).take(3);
  71. self.connect(addr, strategy).await
  72. }
  73. pub async fn stop(&self) {
  74. if self
  75. .conn_state_notify
  76. .read()
  77. .await
  78. .conn_state
  79. .is_connected()
  80. {
  81. tracing::trace!("[{}] stop", self);
  82. self
  83. .conn_state_notify
  84. .write()
  85. .await
  86. .update_state(WSConnectState::Disconnected);
  87. }
  88. }
  89. async fn connect<T, I>(&self, addr: String, strategy: T) -> Result<(), WSError>
  90. where
  91. T: IntoIterator<IntoIter = I, Item = Duration>,
  92. I: Iterator<Item = Duration> + Send + 'static,
  93. {
  94. let mut conn_state_notify = self.conn_state_notify.write().await;
  95. let conn_state = conn_state_notify.conn_state.clone();
  96. if conn_state.is_connected() || conn_state.is_connecting() {
  97. return Ok(());
  98. }
  99. let (ret, rx) = oneshot::channel::<Result<(), WSError>>();
  100. *self.addr.write().await = Some(addr.clone());
  101. let action = WSConnectAction {
  102. addr,
  103. handlers: self.handlers.clone(),
  104. };
  105. let retry = Retry::new(strategy, action);
  106. conn_state_notify.update_state(WSConnectState::Connecting);
  107. drop(conn_state_notify);
  108. let cloned_conn_state = self.conn_state_notify.clone();
  109. let cloned_sender = self.sender.clone();
  110. tracing::trace!("[{}] start connecting", self);
  111. tokio::spawn(async move {
  112. match retry.await {
  113. Ok(result) => {
  114. let WSConnectResult {
  115. stream,
  116. handlers_fut,
  117. sender,
  118. } = result;
  119. cloned_conn_state
  120. .write()
  121. .await
  122. .update_state(WSConnectState::Connected);
  123. *cloned_sender.write().await = Some(Arc::new(sender));
  124. let _ = ret.send(Ok(()));
  125. spawn_stream_and_handlers(stream, handlers_fut).await;
  126. },
  127. Err(e) => {
  128. cloned_conn_state
  129. .write()
  130. .await
  131. .update_state(WSConnectState::Disconnected);
  132. let _ = ret.send(Err(WSError::internal().context(e)));
  133. },
  134. }
  135. });
  136. rx.await?
  137. }
  138. pub async fn retry(&self, count: usize) -> Result<(), WSError> {
  139. if !self
  140. .conn_state_notify
  141. .read()
  142. .await
  143. .conn_state
  144. .is_disconnected()
  145. {
  146. return Ok(());
  147. }
  148. tracing::trace!("[WebSocket]: retry connect...");
  149. let strategy = FixedInterval::from_millis(5000).take(count);
  150. let addr = self
  151. .addr
  152. .read()
  153. .await
  154. .as_ref()
  155. .expect("Retry web socket connection failed, should call start_connect first")
  156. .clone();
  157. self.connect(addr, strategy).await
  158. }
  159. pub async fn subscribe_state(&self) -> broadcast::Receiver<WSConnectState> {
  160. self.conn_state_notify.read().await.notify.subscribe()
  161. }
  162. pub async fn ws_message_sender(&self) -> Result<Option<Arc<WSSender>>, WSError> {
  163. let sender = self.sender.read().await.clone();
  164. match sender {
  165. None => match self.conn_state_notify.read().await.conn_state {
  166. WSConnectState::Disconnected => {
  167. let msg = "WebSocket is disconnected";
  168. Err(WSError::internal().context(msg))
  169. },
  170. _ => Ok(None),
  171. },
  172. Some(sender) => Ok(Some(sender)),
  173. }
  174. }
  175. }
  176. async fn spawn_stream_and_handlers(stream: WSStream, handlers: WSHandlerFuture) {
  177. tokio::select! {
  178. result = stream => {
  179. if let Err(e) = result {
  180. tracing::error!("WSStream error: {:?}", e);
  181. }
  182. },
  183. result = handlers => tracing::debug!("handlers completed {:?}", result),
  184. };
  185. }
  186. #[pin_project]
  187. pub struct WSHandlerFuture {
  188. #[pin]
  189. msg_rx: MsgReceiver,
  190. handlers: Handlers,
  191. }
  192. impl WSHandlerFuture {
  193. fn new(handlers: Handlers, msg_rx: MsgReceiver) -> Self {
  194. Self { msg_rx, handlers }
  195. }
  196. fn handler_ws_message(&self, message: Message) {
  197. if let Message::Binary(bytes) = message {
  198. self.handle_binary_message(bytes)
  199. }
  200. }
  201. fn handle_binary_message(&self, bytes: Vec<u8>) {
  202. let msg = WebSocketRawMessage::from_bytes(bytes);
  203. match self.handlers.get(&msg.channel) {
  204. None => log::error!("Can't find any handler for message: {:?}", msg),
  205. Some(handler) => handler.receive_message(msg),
  206. }
  207. }
  208. }
  209. impl Future for WSHandlerFuture {
  210. type Output = ();
  211. fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  212. loop {
  213. match ready!(self.as_mut().project().msg_rx.poll_next(cx)) {
  214. None => {
  215. return Poll::Ready(());
  216. },
  217. Some(message) => self.handler_ws_message(message),
  218. }
  219. }
  220. }
  221. }
  222. #[derive(Debug, Clone)]
  223. pub struct WSSender(MsgSender);
  224. impl WSSender {
  225. pub fn send_msg<T: Into<WebSocketRawMessage>>(&self, msg: T) -> Result<(), WSError> {
  226. let msg = msg.into();
  227. self
  228. .0
  229. .unbounded_send(msg.into())
  230. .map_err(|e| WSError::internal().context(e))?;
  231. Ok(())
  232. }
  233. pub fn send_text(&self, source: &WSChannel, text: &str) -> Result<(), WSError> {
  234. let msg = WebSocketRawMessage {
  235. channel: source.clone(),
  236. data: text.as_bytes().to_vec(),
  237. };
  238. self.send_msg(msg)
  239. }
  240. pub fn send_binary(&self, source: &WSChannel, bytes: Vec<u8>) -> Result<(), WSError> {
  241. let msg = WebSocketRawMessage {
  242. channel: source.clone(),
  243. data: bytes,
  244. };
  245. self.send_msg(msg)
  246. }
  247. pub fn send_disconnect(&self, reason: &str) -> Result<(), WSError> {
  248. let frame = CloseFrame {
  249. code: CloseCode::Normal,
  250. reason: reason.to_owned().into(),
  251. };
  252. let msg = Message::Close(Some(frame));
  253. self
  254. .0
  255. .unbounded_send(msg)
  256. .map_err(|e| WSError::internal().context(e))?;
  257. Ok(())
  258. }
  259. }
  260. struct WSConnectAction {
  261. addr: String,
  262. handlers: Handlers,
  263. }
  264. impl Action for WSConnectAction {
  265. type Future = Pin<Box<dyn Future<Output = Result<Self::Item, Self::Error>> + Send + Sync>>;
  266. type Item = WSConnectResult;
  267. type Error = WSError;
  268. fn run(&mut self) -> Self::Future {
  269. let addr = self.addr.clone();
  270. let handlers = self.handlers.clone();
  271. Box::pin(WSConnectActionFut::new(addr, handlers))
  272. }
  273. }
  274. struct WSConnectResult {
  275. stream: WSStream,
  276. handlers_fut: WSHandlerFuture,
  277. sender: WSSender,
  278. }
  279. #[pin_project]
  280. struct WSConnectActionFut {
  281. addr: String,
  282. #[pin]
  283. conn: WSConnectionFuture,
  284. handlers_fut: Option<WSHandlerFuture>,
  285. sender: Option<WSSender>,
  286. }
  287. impl WSConnectActionFut {
  288. fn new(addr: String, handlers: Handlers) -> Self {
  289. // Stream User
  290. // ┌───────────────┐ ┌──────────────┐
  291. // ┌──────┐ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
  292. // │Server│──────┼─▶│ ws_read │──┼───▶│ msg_tx │───┼─▶│ msg_rx │ │
  293. // └──────┘ │ └─────────┘ │ └────────┘ │ └────────┘ │
  294. // ▲ │ │ │ │
  295. // │ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
  296. // └─────────┼──│ws_write │◀─┼────│ ws_rx │◀──┼──│ ws_tx │ │
  297. // │ └─────────┘ │ └────────┘ │ └────────┘ │
  298. // └───────────────┘ └──────────────┘
  299. let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
  300. let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
  301. let sender = WSSender(ws_tx);
  302. let handlers_fut = WSHandlerFuture::new(handlers, msg_rx);
  303. let conn = WSConnectionFuture::new(msg_tx, ws_rx, addr.clone());
  304. Self {
  305. addr,
  306. conn,
  307. handlers_fut: Some(handlers_fut),
  308. sender: Some(sender),
  309. }
  310. }
  311. }
  312. impl Future for WSConnectActionFut {
  313. type Output = Result<WSConnectResult, WSError>;
  314. fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  315. let mut this = self.project();
  316. match ready!(this.conn.as_mut().poll(cx)) {
  317. Ok(stream) => {
  318. let handlers_fut = this.handlers_fut.take().expect("Only take once");
  319. let sender = this.sender.take().expect("Only take once");
  320. Poll::Ready(Ok(WSConnectResult {
  321. stream,
  322. handlers_fut,
  323. sender,
  324. }))
  325. },
  326. Err(e) => Poll::Ready(Err(e)),
  327. }
  328. }
  329. }
  /// Lifecycle of the controller's connection.
  #[derive(Clone, Eq, PartialEq)]
  pub enum WSConnectState {
      /// No connection has been attempted yet.
      Init,
      /// A connection attempt, including its retries, is in progress.
      Connecting,
      /// The handshake succeeded.
      Connected,
      /// The last attempt failed or the connection was stopped.
      Disconnected,
  }

  impl WSConnectState {
      fn is_connected(&self) -> bool {
          matches!(self, WSConnectState::Connected)
      }

      fn is_connecting(&self) -> bool {
          matches!(self, WSConnectState::Connecting)
      }

      /// `Init` also counts as disconnected: nothing is connected yet.
      fn is_disconnected(&self) -> bool {
          matches!(self, WSConnectState::Disconnected | WSConnectState::Init)
      }
  }

  impl std::fmt::Display for WSConnectState {
      fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
          let label = match self {
              WSConnectState::Init => "Init",
              WSConnectState::Connecting => "Connecting",
              WSConnectState::Connected => "Connected",
              WSConnectState::Disconnected => "Disconnected",
          };
          f.write_str(label)
      }
  }

  impl std::fmt::Debug for WSConnectState {
      /// Same text as `Display`.
      fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
          std::fmt::Display::fmt(self, f)
      }
  }
  363. struct WSConnectStateNotifier {
  364. conn_state: WSConnectState,
  365. notify: Arc<broadcast::Sender<WSConnectState>>,
  366. }
  367. impl std::default::Default for WSConnectStateNotifier {
  368. fn default() -> Self {
  369. let (state_notify, _) = broadcast::channel(16);
  370. Self {
  371. conn_state: WSConnectState::Init,
  372. notify: Arc::new(state_notify),
  373. }
  374. }
  375. }
  376. impl WSConnectStateNotifier {
  377. fn update_state(&mut self, new_state: WSConnectState) {
  378. if self.conn_state == new_state {
  379. return;
  380. }
  381. tracing::debug!(
  382. "WebSocket connect state did change: {} -> {}",
  383. self.conn_state,
  384. new_state
  385. );
  386. self.conn_state = new_state.clone();
  387. let _ = self.notify.send(new_state);
  388. }
  389. }