ws.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. #![allow(clippy::type_complexity)]
  2. use crate::{
  3. connect::{WSConnectionFuture, WSStream},
  4. errors::WSError,
  5. WSChannel, WebSocketRawMessage,
  6. };
  7. use bytes::Bytes;
  8. use dashmap::DashMap;
  9. use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
  10. use futures_core::{ready, Stream};
  11. use lib_infra::retry::{Action, FixedInterval, Retry};
  12. use pin_project::pin_project;
  13. use std::{
  14. convert::TryFrom,
  15. fmt::Formatter,
  16. future::Future,
  17. pin::Pin,
  18. sync::Arc,
  19. task::{Context, Poll},
  20. time::Duration,
  21. };
  22. use tokio::sync::{broadcast, oneshot, RwLock};
  23. use tokio_tungstenite::tungstenite::{
  24. protocol::{frame::coding::CloseCode, CloseFrame},
  25. Message,
  26. };
  27. pub type MsgReceiver = UnboundedReceiver<Message>;
  28. pub type MsgSender = UnboundedSender<Message>;
  29. type Handlers = DashMap<WSChannel, Arc<dyn WSMessageReceiver>>;
  30. pub trait WSMessageReceiver: Sync + Send + 'static {
  31. fn source(&self) -> WSChannel;
  32. fn receive_message(&self, msg: WebSocketRawMessage);
  33. }
  34. pub struct WSController {
  35. handlers: Handlers,
  36. addr: Arc<RwLock<Option<String>>>,
  37. sender: Arc<RwLock<Option<Arc<WSSender>>>>,
  38. conn_state_notify: Arc<RwLock<WSConnectStateNotifier>>,
  39. }
  40. impl std::fmt::Display for WSController {
  41. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  42. f.write_str("WebSocket")
  43. }
  44. }
  45. impl std::default::Default for WSController {
  46. fn default() -> Self {
  47. Self {
  48. handlers: DashMap::new(),
  49. addr: Arc::new(RwLock::new(None)),
  50. sender: Arc::new(RwLock::new(None)),
  51. conn_state_notify: Arc::new(RwLock::new(WSConnectStateNotifier::default())),
  52. }
  53. }
  54. }
  55. impl WSController {
  56. pub fn new() -> Self {
  57. WSController::default()
  58. }
  59. pub fn add_ws_message_receiver(&self, handler: Arc<dyn WSMessageReceiver>) -> Result<(), WSError> {
  60. let source = handler.source();
  61. if self.handlers.contains_key(&source) {
  62. log::error!("{:?} is already registered", source);
  63. }
  64. self.handlers.insert(source, handler);
  65. Ok(())
  66. }
  67. pub async fn start(&self, addr: String) -> Result<(), WSError> {
  68. *self.addr.write().await = Some(addr.clone());
  69. let strategy = FixedInterval::from_millis(5000).take(3);
  70. self.connect(addr, strategy).await
  71. }
  72. pub async fn stop(&self) {
  73. if self.conn_state_notify.read().await.conn_state.is_connected() {
  74. tracing::trace!("[{}] stop", self);
  75. self.conn_state_notify
  76. .write()
  77. .await
  78. .update_state(WSConnectState::Disconnected);
  79. }
  80. }
  81. async fn connect<T, I>(&self, addr: String, strategy: T) -> Result<(), WSError>
  82. where
  83. T: IntoIterator<IntoIter = I, Item = Duration>,
  84. I: Iterator<Item = Duration> + Send + 'static,
  85. {
  86. let mut conn_state_notify = self.conn_state_notify.write().await;
  87. let conn_state = conn_state_notify.conn_state.clone();
  88. if conn_state.is_connected() || conn_state.is_connecting() {
  89. return Ok(());
  90. }
  91. let (ret, rx) = oneshot::channel::<Result<(), WSError>>();
  92. *self.addr.write().await = Some(addr.clone());
  93. let action = WSConnectAction {
  94. addr,
  95. handlers: self.handlers.clone(),
  96. };
  97. let retry = Retry::spawn(strategy, action);
  98. conn_state_notify.update_state(WSConnectState::Connecting);
  99. drop(conn_state_notify);
  100. let cloned_conn_state = self.conn_state_notify.clone();
  101. let cloned_sender = self.sender.clone();
  102. tracing::trace!("[{}] start connecting", self);
  103. tokio::spawn(async move {
  104. match retry.await {
  105. Ok(result) => {
  106. let WSConnectResult {
  107. stream,
  108. handlers_fut,
  109. sender,
  110. } = result;
  111. cloned_conn_state.write().await.update_state(WSConnectState::Connected);
  112. *cloned_sender.write().await = Some(Arc::new(sender));
  113. let _ = ret.send(Ok(()));
  114. spawn_stream_and_handlers(stream, handlers_fut).await;
  115. }
  116. Err(e) => {
  117. cloned_conn_state
  118. .write()
  119. .await
  120. .update_state(WSConnectState::Disconnected);
  121. let _ = ret.send(Err(WSError::internal().context(e)));
  122. }
  123. }
  124. });
  125. rx.await?
  126. }
  127. pub async fn retry(&self, count: usize) -> Result<(), WSError> {
  128. if !self.conn_state_notify.read().await.conn_state.is_disconnected() {
  129. return Ok(());
  130. }
  131. tracing::trace!("[WebSocket]: retry connect...");
  132. let strategy = FixedInterval::from_millis(5000).take(count);
  133. let addr = self
  134. .addr
  135. .read()
  136. .await
  137. .as_ref()
  138. .expect("Retry web socket connection failed, should call start_connect first")
  139. .clone();
  140. self.connect(addr, strategy).await
  141. }
  142. pub async fn subscribe_state(&self) -> broadcast::Receiver<WSConnectState> {
  143. self.conn_state_notify.read().await.notify.subscribe()
  144. }
  145. pub async fn ws_message_sender(&self) -> Result<Option<Arc<WSSender>>, WSError> {
  146. let sender = self.sender.read().await.clone();
  147. match sender {
  148. None => match self.conn_state_notify.read().await.conn_state {
  149. WSConnectState::Disconnected => {
  150. let msg = "WebSocket is disconnected";
  151. Err(WSError::internal().context(msg))
  152. }
  153. _ => Ok(None),
  154. },
  155. Some(sender) => Ok(Some(sender)),
  156. }
  157. }
  158. }
  159. async fn spawn_stream_and_handlers(stream: WSStream, handlers: WSHandlerFuture) {
  160. tokio::select! {
  161. result = stream => {
  162. if let Err(e) = result {
  163. tracing::error!("WSStream error: {:?}", e);
  164. }
  165. },
  166. result = handlers => tracing::debug!("handlers completed {:?}", result),
  167. };
  168. }
  169. #[pin_project]
  170. pub struct WSHandlerFuture {
  171. #[pin]
  172. msg_rx: MsgReceiver,
  173. handlers: Handlers,
  174. }
  175. impl WSHandlerFuture {
  176. fn new(handlers: Handlers, msg_rx: MsgReceiver) -> Self {
  177. Self { msg_rx, handlers }
  178. }
  179. fn handler_ws_message(&self, message: Message) {
  180. if let Message::Binary(bytes) = message {
  181. self.handle_binary_message(bytes)
  182. }
  183. }
  184. fn handle_binary_message(&self, bytes: Vec<u8>) {
  185. let bytes = Bytes::from(bytes);
  186. match WebSocketRawMessage::try_from(bytes) {
  187. Ok(message) => match self.handlers.get(&message.channel) {
  188. None => log::error!("Can't find any handler for message: {:?}", message),
  189. Some(handler) => handler.receive_message(message.clone()),
  190. },
  191. Err(e) => {
  192. log::error!("Deserialize binary ws message failed: {:?}", e);
  193. }
  194. }
  195. }
  196. }
  197. impl Future for WSHandlerFuture {
  198. type Output = ();
  199. fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  200. loop {
  201. match ready!(self.as_mut().project().msg_rx.poll_next(cx)) {
  202. None => {
  203. return Poll::Ready(());
  204. }
  205. Some(message) => self.handler_ws_message(message),
  206. }
  207. }
  208. }
  209. }
  210. #[derive(Debug, Clone)]
  211. pub struct WSSender(MsgSender);
  212. impl WSSender {
  213. pub fn send_msg<T: Into<WebSocketRawMessage>>(&self, msg: T) -> Result<(), WSError> {
  214. let msg = msg.into();
  215. let _ = self
  216. .0
  217. .unbounded_send(msg.into())
  218. .map_err(|e| WSError::internal().context(e))?;
  219. Ok(())
  220. }
  221. pub fn send_text(&self, source: &WSChannel, text: &str) -> Result<(), WSError> {
  222. let msg = WebSocketRawMessage {
  223. channel: source.clone(),
  224. data: text.as_bytes().to_vec(),
  225. };
  226. self.send_msg(msg)
  227. }
  228. pub fn send_binary(&self, source: &WSChannel, bytes: Vec<u8>) -> Result<(), WSError> {
  229. let msg = WebSocketRawMessage {
  230. channel: source.clone(),
  231. data: bytes,
  232. };
  233. self.send_msg(msg)
  234. }
  235. pub fn send_disconnect(&self, reason: &str) -> Result<(), WSError> {
  236. let frame = CloseFrame {
  237. code: CloseCode::Normal,
  238. reason: reason.to_owned().into(),
  239. };
  240. let msg = Message::Close(Some(frame));
  241. let _ = self.0.unbounded_send(msg).map_err(|e| WSError::internal().context(e))?;
  242. Ok(())
  243. }
  244. }
  245. struct WSConnectAction {
  246. addr: String,
  247. handlers: Handlers,
  248. }
  249. impl Action for WSConnectAction {
  250. type Future = Pin<Box<dyn Future<Output = Result<Self::Item, Self::Error>> + Send + Sync>>;
  251. type Item = WSConnectResult;
  252. type Error = WSError;
  253. fn run(&mut self) -> Self::Future {
  254. let addr = self.addr.clone();
  255. let handlers = self.handlers.clone();
  256. Box::pin(WSConnectActionFut::new(addr, handlers))
  257. }
  258. }
  259. struct WSConnectResult {
  260. stream: WSStream,
  261. handlers_fut: WSHandlerFuture,
  262. sender: WSSender,
  263. }
  264. #[pin_project]
  265. struct WSConnectActionFut {
  266. addr: String,
  267. #[pin]
  268. conn: WSConnectionFuture,
  269. handlers_fut: Option<WSHandlerFuture>,
  270. sender: Option<WSSender>,
  271. }
  272. impl WSConnectActionFut {
  273. fn new(addr: String, handlers: Handlers) -> Self {
  274. // Stream User
  275. // ┌───────────────┐ ┌──────────────┐
  276. // ┌──────┐ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
  277. // │Server│──────┼─▶│ ws_read │──┼───▶│ msg_tx │───┼─▶│ msg_rx │ │
  278. // └──────┘ │ └─────────┘ │ └────────┘ │ └────────┘ │
  279. // ▲ │ │ │ │
  280. // │ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
  281. // └─────────┼──│ws_write │◀─┼────│ ws_rx │◀──┼──│ ws_tx │ │
  282. // │ └─────────┘ │ └────────┘ │ └────────┘ │
  283. // └───────────────┘ └──────────────┘
  284. let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
  285. let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
  286. let sender = WSSender(ws_tx);
  287. let handlers_fut = WSHandlerFuture::new(handlers, msg_rx);
  288. let conn = WSConnectionFuture::new(msg_tx, ws_rx, addr.clone());
  289. Self {
  290. addr,
  291. conn,
  292. handlers_fut: Some(handlers_fut),
  293. sender: Some(sender),
  294. }
  295. }
  296. }
  297. impl Future for WSConnectActionFut {
  298. type Output = Result<WSConnectResult, WSError>;
  299. fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  300. let mut this = self.project();
  301. match ready!(this.conn.as_mut().poll(cx)) {
  302. Ok(stream) => {
  303. let handlers_fut = this.handlers_fut.take().expect("Only take once");
  304. let sender = this.sender.take().expect("Only take once");
  305. Poll::Ready(Ok(WSConnectResult {
  306. stream,
  307. handlers_fut,
  308. sender,
  309. }))
  310. }
  311. Err(e) => Poll::Ready(Err(e)),
  312. }
  313. }
  314. }
  315. #[derive(Clone, Eq, PartialEq)]
  316. pub enum WSConnectState {
  317. Init,
  318. Connecting,
  319. Connected,
  320. Disconnected,
  321. }
  322. impl WSConnectState {
  323. fn is_connected(&self) -> bool {
  324. self == &WSConnectState::Connected
  325. }
  326. fn is_connecting(&self) -> bool {
  327. self == &WSConnectState::Connecting
  328. }
  329. fn is_disconnected(&self) -> bool {
  330. self == &WSConnectState::Disconnected || self == &WSConnectState::Init
  331. }
  332. }
  333. impl std::fmt::Display for WSConnectState {
  334. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  335. match self {
  336. WSConnectState::Init => f.write_str("Init"),
  337. WSConnectState::Connected => f.write_str("Connected"),
  338. WSConnectState::Connecting => f.write_str("Connecting"),
  339. WSConnectState::Disconnected => f.write_str("Disconnected"),
  340. }
  341. }
  342. }
  343. impl std::fmt::Debug for WSConnectState {
  344. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  345. f.write_str(&format!("{}", self))
  346. }
  347. }
  348. struct WSConnectStateNotifier {
  349. conn_state: WSConnectState,
  350. notify: Arc<broadcast::Sender<WSConnectState>>,
  351. }
  352. impl std::default::Default for WSConnectStateNotifier {
  353. fn default() -> Self {
  354. let (state_notify, _) = broadcast::channel(16);
  355. Self {
  356. conn_state: WSConnectState::Init,
  357. notify: Arc::new(state_notify),
  358. }
  359. }
  360. }
  361. impl WSConnectStateNotifier {
  362. fn update_state(&mut self, new_state: WSConnectState) {
  363. if self.conn_state == new_state {
  364. return;
  365. }
  366. tracing::debug!(
  367. "WebSocket connect state did change: {} -> {}",
  368. self.conn_state,
  369. new_state
  370. );
  371. self.conn_state = new_state.clone();
  372. let _ = self.notify.send(new_state);
  373. }
  374. }