ws_manager.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. use crate::ConflictRevisionSink;
  2. use async_stream::stream;
  3. use flowy_error::{FlowyError, FlowyResult};
  4. use flowy_http_model::revision::{Revision, RevisionRange};
  5. use flowy_http_model::ws_data::{ClientRevisionWSData, NewDocumentUser, ServerRevisionWSData, WSRevisionPayload};
  6. use futures_util::{future::BoxFuture, stream::StreamExt};
  7. use lib_infra::future::{BoxResultFuture, FutureResult};
  8. use lib_ws::WSConnectState;
  9. use std::{collections::VecDeque, fmt::Formatter, sync::Arc};
  10. use tokio::{
  11. sync::{
  12. broadcast, mpsc,
  13. mpsc::{Receiver, Sender},
  14. RwLock,
  15. },
  16. time::{interval, Duration},
  17. };
  18. // The consumer consumes the messages pushed by the web socket.
  19. pub trait RevisionWSDataStream: Send + Sync {
  20. fn receive_push_revision(&self, revisions: Vec<Revision>) -> BoxResultFuture<(), FlowyError>;
  21. fn receive_ack(&self, rev_id: i64) -> BoxResultFuture<(), FlowyError>;
  22. fn receive_new_user_connect(&self, new_user: NewDocumentUser) -> BoxResultFuture<(), FlowyError>;
  23. fn pull_revisions_in_range(&self, range: RevisionRange) -> BoxResultFuture<(), FlowyError>;
  24. }
  25. // The sink provides the data that will be sent through the web socket to the
  26. // server.
  27. pub trait RevisionWebSocketSink: Send + Sync {
  28. fn next(&self) -> FutureResult<Option<ClientRevisionWSData>, FlowyError>;
  29. }
  30. pub type WSStateReceiver = tokio::sync::broadcast::Receiver<WSConnectState>;
  31. pub trait RevisionWebSocket: Send + Sync + 'static {
  32. fn send(&self, data: ClientRevisionWSData) -> BoxResultFuture<(), FlowyError>;
  33. fn subscribe_state_changed(&self) -> BoxFuture<WSStateReceiver>;
  34. }
  35. pub struct RevisionWebSocketManager {
  36. pub object_name: String,
  37. pub object_id: String,
  38. ws_data_sink: Arc<dyn RevisionWebSocketSink>,
  39. ws_data_stream: Arc<dyn RevisionWSDataStream>,
  40. rev_web_socket: Arc<dyn RevisionWebSocket>,
  41. pub ws_passthrough_tx: Sender<ServerRevisionWSData>,
  42. ws_passthrough_rx: Option<Receiver<ServerRevisionWSData>>,
  43. pub state_passthrough_tx: broadcast::Sender<WSConnectState>,
  44. stop_sync_tx: SinkStopTx,
  45. }
  46. impl std::fmt::Display for RevisionWebSocketManager {
  47. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  48. f.write_fmt(format_args!("{}RevisionWebSocketManager", self.object_name))
  49. }
  50. }
  51. impl RevisionWebSocketManager {
  52. pub fn new(
  53. object_name: &str,
  54. object_id: &str,
  55. rev_web_socket: Arc<dyn RevisionWebSocket>,
  56. ws_data_sink: Arc<dyn RevisionWebSocketSink>,
  57. ws_data_stream: Arc<dyn RevisionWSDataStream>,
  58. ping_duration: Duration,
  59. ) -> Self {
  60. let (ws_passthrough_tx, ws_passthrough_rx) = mpsc::channel(1000);
  61. let (stop_sync_tx, _) = tokio::sync::broadcast::channel(2);
  62. let object_id = object_id.to_string();
  63. let object_name = object_name.to_string();
  64. let (state_passthrough_tx, _) = broadcast::channel(2);
  65. let mut manager = RevisionWebSocketManager {
  66. object_id,
  67. object_name,
  68. ws_data_sink,
  69. ws_data_stream,
  70. rev_web_socket,
  71. ws_passthrough_tx,
  72. ws_passthrough_rx: Some(ws_passthrough_rx),
  73. state_passthrough_tx,
  74. stop_sync_tx,
  75. };
  76. manager.run(ping_duration);
  77. manager
  78. }
  79. fn run(&mut self, ping_duration: Duration) {
  80. let ws_passthrough_rx = self.ws_passthrough_rx.take().expect("Only take once");
  81. let sink = RevisionWSSink::new(
  82. &self.object_id,
  83. &self.object_name,
  84. self.ws_data_sink.clone(),
  85. self.rev_web_socket.clone(),
  86. self.stop_sync_tx.subscribe(),
  87. ping_duration,
  88. );
  89. let stream = RevisionWSStream::new(
  90. &self.object_name,
  91. &self.object_id,
  92. self.ws_data_stream.clone(),
  93. ws_passthrough_rx,
  94. self.stop_sync_tx.subscribe(),
  95. );
  96. tokio::spawn(sink.run());
  97. tokio::spawn(stream.run());
  98. }
  99. pub fn scribe_state(&self) -> broadcast::Receiver<WSConnectState> {
  100. self.state_passthrough_tx.subscribe()
  101. }
  102. pub fn stop(&self) {
  103. if self.stop_sync_tx.send(()).is_ok() {
  104. tracing::trace!("{} stop sync", self.object_id)
  105. }
  106. }
  107. #[tracing::instrument(level = "debug", skip(self, data), err)]
  108. pub async fn receive_ws_data(&self, data: ServerRevisionWSData) -> Result<(), FlowyError> {
  109. self.ws_passthrough_tx.send(data).await.map_err(|e| {
  110. let err_msg = format!("{} passthrough error: {}", self.object_id, e);
  111. FlowyError::internal().context(err_msg)
  112. })?;
  113. Ok(())
  114. }
  115. pub fn connect_state_changed(&self, state: WSConnectState) {
  116. match self.state_passthrough_tx.send(state) {
  117. Ok(_) => {}
  118. Err(e) => tracing::error!("{}", e),
  119. }
  120. }
  121. }
  122. impl std::ops::Drop for RevisionWebSocketManager {
  123. fn drop(&mut self) {
  124. tracing::trace!("{} was dropped", self)
  125. }
  126. }
  127. pub struct RevisionWSStream {
  128. object_name: String,
  129. object_id: String,
  130. consumer: Arc<dyn RevisionWSDataStream>,
  131. ws_msg_rx: Option<mpsc::Receiver<ServerRevisionWSData>>,
  132. stop_rx: Option<SinkStopRx>,
  133. }
  134. impl std::fmt::Display for RevisionWSStream {
  135. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  136. f.write_fmt(format_args!("{}RevisionWSStream", self.object_name))
  137. }
  138. }
  139. impl std::ops::Drop for RevisionWSStream {
  140. fn drop(&mut self) {
  141. tracing::trace!("{} was dropped", self)
  142. }
  143. }
  144. impl RevisionWSStream {
  145. pub fn new(
  146. object_name: &str,
  147. object_id: &str,
  148. consumer: Arc<dyn RevisionWSDataStream>,
  149. ws_msg_rx: mpsc::Receiver<ServerRevisionWSData>,
  150. stop_rx: SinkStopRx,
  151. ) -> Self {
  152. RevisionWSStream {
  153. object_name: object_name.to_string(),
  154. object_id: object_id.to_owned(),
  155. consumer,
  156. ws_msg_rx: Some(ws_msg_rx),
  157. stop_rx: Some(stop_rx),
  158. }
  159. }
  160. pub async fn run(mut self) {
  161. let mut receiver = self.ws_msg_rx.take().expect("Only take once");
  162. let mut stop_rx = self.stop_rx.take().expect("Only take once");
  163. let object_id = self.object_id.clone();
  164. let name = format!("{}", &self);
  165. let stream = stream! {
  166. loop {
  167. tokio::select! {
  168. result = receiver.recv() => {
  169. match result {
  170. Some(msg) => {
  171. yield msg
  172. },
  173. None => {
  174. tracing::debug!("[{}]:{} loop exit", name, object_id);
  175. break;
  176. },
  177. }
  178. },
  179. _ = stop_rx.recv() => {
  180. tracing::debug!("[{}]:{} loop exit", name, object_id);
  181. break
  182. },
  183. };
  184. }
  185. };
  186. stream
  187. .for_each(|msg| async {
  188. match self.handle_message(msg).await {
  189. Ok(_) => {}
  190. Err(e) => tracing::error!("[{}]:{} error: {}", &self, self.object_id, e),
  191. }
  192. })
  193. .await;
  194. }
  195. async fn handle_message(&self, msg: ServerRevisionWSData) -> FlowyResult<()> {
  196. let ServerRevisionWSData { object_id, payload } = msg;
  197. match payload {
  198. WSRevisionPayload::ServerPushRev { revisions } => {
  199. tracing::trace!("[{}]: new push revision: {}", self, object_id);
  200. self.consumer.receive_push_revision(revisions).await?;
  201. }
  202. WSRevisionPayload::ServerPullRev { range } => {
  203. tracing::trace!("[{}]: new pull: {}:{:?}", self, object_id, range);
  204. self.consumer.pull_revisions_in_range(range).await?;
  205. }
  206. WSRevisionPayload::ServerAck { rev_id } => {
  207. tracing::trace!("[{}]: new ack: {}:{}", self, object_id, rev_id);
  208. let _ = self.consumer.receive_ack(rev_id).await;
  209. }
  210. WSRevisionPayload::UserConnect { user } => {
  211. let _ = self.consumer.receive_new_user_connect(user).await;
  212. }
  213. }
  214. Ok(())
  215. }
  216. }
  217. type SinkStopRx = broadcast::Receiver<()>;
  218. type SinkStopTx = broadcast::Sender<()>;
  219. pub struct RevisionWSSink {
  220. object_id: String,
  221. object_name: String,
  222. provider: Arc<dyn RevisionWebSocketSink>,
  223. rev_web_socket: Arc<dyn RevisionWebSocket>,
  224. stop_rx: Option<SinkStopRx>,
  225. ping_duration: Duration,
  226. }
  227. impl RevisionWSSink {
  228. pub fn new(
  229. object_id: &str,
  230. object_name: &str,
  231. provider: Arc<dyn RevisionWebSocketSink>,
  232. rev_web_socket: Arc<dyn RevisionWebSocket>,
  233. stop_rx: SinkStopRx,
  234. ping_duration: Duration,
  235. ) -> Self {
  236. Self {
  237. object_id: object_id.to_owned(),
  238. object_name: object_name.to_owned(),
  239. provider,
  240. rev_web_socket,
  241. stop_rx: Some(stop_rx),
  242. ping_duration,
  243. }
  244. }
  245. pub async fn run(mut self) {
  246. let (tx, mut rx) = mpsc::channel(1);
  247. let mut stop_rx = self.stop_rx.take().expect("Only take once");
  248. let object_id = self.object_id.clone();
  249. tokio::spawn(tick(tx, self.ping_duration));
  250. let name = format!("{}", self);
  251. let stream = stream! {
  252. loop {
  253. tokio::select! {
  254. result = rx.recv() => {
  255. match result {
  256. Some(msg) => yield msg,
  257. None => break,
  258. }
  259. },
  260. _ = stop_rx.recv() => {
  261. tracing::trace!("[{}]:{} loop exit", name, object_id);
  262. break
  263. },
  264. };
  265. }
  266. };
  267. stream
  268. .for_each(|_| async {
  269. match self.send_next_revision().await {
  270. Ok(_) => {}
  271. Err(e) => tracing::error!("[{}] send failed, {:?}", self, e),
  272. }
  273. })
  274. .await;
  275. }
  276. async fn send_next_revision(&self) -> FlowyResult<()> {
  277. match self.provider.next().await? {
  278. None => {
  279. tracing::trace!("[{}]: Finish synchronizing revisions", self);
  280. Ok(())
  281. }
  282. Some(data) => {
  283. tracing::trace!("[{}]: send {}:{}-{:?}", self, data.object_id, data.rev_id, data.ty);
  284. self.rev_web_socket.send(data).await
  285. }
  286. }
  287. }
  288. }
  289. async fn tick(sender: mpsc::Sender<()>, duration: Duration) {
  290. let mut interval = interval(duration);
  291. while sender.send(()).await.is_ok() {
  292. interval.tick().await;
  293. }
  294. }
  295. impl std::fmt::Display for RevisionWSSink {
  296. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  297. f.write_fmt(format_args!("{}RevisionWSSink", self.object_name))
  298. }
  299. }
  300. impl std::ops::Drop for RevisionWSSink {
  301. fn drop(&mut self) {
  302. tracing::trace!("{} was dropped", self)
  303. }
  304. }
  /// Which queue `WSDataProvider::next` currently draws from.
  #[derive(Clone)]
  enum Source {
      /// Data queued explicitly through `WSDataProvider::push_data`.
      Custom,
      /// Revisions produced by the `WSDataProviderDataSource`.
      Revision,
  }
  310. pub trait WSDataProviderDataSource: Send + Sync {
  311. fn next_revision(&self) -> FutureResult<Option<Revision>, FlowyError>;
  312. fn ack_revision(&self, rev_id: i64) -> FutureResult<(), FlowyError>;
  313. fn current_rev_id(&self) -> i64;
  314. }
  315. #[derive(Clone)]
  316. pub struct WSDataProvider {
  317. object_id: String,
  318. rev_ws_data_list: Arc<RwLock<VecDeque<ClientRevisionWSData>>>,
  319. data_source: Arc<dyn WSDataProviderDataSource>,
  320. current_source: Arc<RwLock<Source>>,
  321. }
  322. impl WSDataProvider {
  323. pub fn new(object_id: &str, data_source: Arc<dyn WSDataProviderDataSource>) -> Self {
  324. WSDataProvider {
  325. object_id: object_id.to_owned(),
  326. rev_ws_data_list: Arc::new(RwLock::new(VecDeque::new())),
  327. data_source,
  328. current_source: Arc::new(RwLock::new(Source::Custom)),
  329. }
  330. }
  331. pub async fn push_data(&self, data: ClientRevisionWSData) {
  332. self.rev_ws_data_list.write().await.push_back(data);
  333. }
  334. pub async fn next(&self) -> FlowyResult<Option<ClientRevisionWSData>> {
  335. let source = self.current_source.read().await.clone();
  336. let data = match source {
  337. Source::Custom => match self.rev_ws_data_list.read().await.front() {
  338. None => {
  339. *self.current_source.write().await = Source::Revision;
  340. Ok(None)
  341. }
  342. Some(data) => Ok(Some(data.clone())),
  343. },
  344. Source::Revision => {
  345. if !self.rev_ws_data_list.read().await.is_empty() {
  346. *self.current_source.write().await = Source::Custom;
  347. return Ok(None);
  348. }
  349. match self.data_source.next_revision().await? {
  350. Some(rev) => Ok(Some(ClientRevisionWSData::from_revisions(&self.object_id, vec![rev]))),
  351. None => Ok(Some(ClientRevisionWSData::ping(
  352. &self.object_id,
  353. self.data_source.current_rev_id(),
  354. ))),
  355. }
  356. }
  357. };
  358. data
  359. }
  360. pub async fn ack_data(&self, rev_id: i64) -> FlowyResult<()> {
  361. let source = self.current_source.read().await.clone();
  362. match source {
  363. Source::Custom => {
  364. let should_pop = match self.rev_ws_data_list.read().await.front() {
  365. None => false,
  366. Some(val) => {
  367. if val.rev_id == rev_id {
  368. true
  369. } else {
  370. tracing::error!("The front element's {} is not equal to the {}", val.rev_id, rev_id);
  371. false
  372. }
  373. }
  374. };
  375. if should_pop {
  376. let _ = self.rev_ws_data_list.write().await.pop_front();
  377. }
  378. Ok(())
  379. }
  380. Source::Revision => {
  381. self.data_source.ack_revision(rev_id).await?;
  382. Ok::<(), FlowyError>(())
  383. }
  384. }
  385. }
  386. }
  387. impl ConflictRevisionSink for Arc<WSDataProvider> {
  388. fn send(&self, revisions: Vec<Revision>) -> BoxResultFuture<(), FlowyError> {
  389. let sink = self.clone();
  390. Box::pin(async move {
  391. sink.push_data(ClientRevisionWSData::from_revisions(&sink.object_id, revisions))
  392. .await;
  393. Ok(())
  394. })
  395. }
  396. fn ack(&self, rev_id: i64) -> BoxResultFuture<(), FlowyError> {
  397. let sink = self.clone();
  398. Box::pin(async move { sink.ack_data(rev_id).await })
  399. }
  400. }