// ws_manager.rs
//
// Revision synchronization over a web socket: a manager that spawns a sink
// task (outgoing revisions) and a stream task (incoming server messages), and
// a data provider that decides what the sink sends next.
  1. use crate::ConflictRevisionSink;
  2. use async_stream::stream;
  3. use flowy_error::{FlowyError, FlowyResult};
  4. use futures_util::{future::BoxFuture, stream::StreamExt};
  5. use lib_infra::future::{BoxResultFuture, FutureResult};
  6. use lib_ws::WSConnectState;
  7. use revision_model::{Revision, RevisionRange};
  8. use std::{collections::VecDeque, fmt::Formatter, sync::Arc};
  9. use tokio::{
  10. sync::{
  11. broadcast, mpsc,
  12. mpsc::{Receiver, Sender},
  13. RwLock,
  14. },
  15. time::{interval, Duration},
  16. };
  17. use ws_model::ws_revision::{
  18. ClientRevisionWSData, NewDocumentUser, ServerRevisionWSData, WSRevisionPayload,
  19. };
  20. // The consumer consumes the messages pushed by the web socket.
  21. pub trait RevisionWSDataStream: Send + Sync {
  22. fn receive_push_revision(&self, revisions: Vec<Revision>) -> BoxResultFuture<(), FlowyError>;
  23. fn receive_ack(&self, rev_id: i64) -> BoxResultFuture<(), FlowyError>;
  24. fn receive_new_user_connect(&self, new_user: NewDocumentUser) -> BoxResultFuture<(), FlowyError>;
  25. fn pull_revisions_in_range(&self, range: RevisionRange) -> BoxResultFuture<(), FlowyError>;
  26. }
  27. // The sink provides the data that will be sent through the web socket to the
  28. // server.
  29. pub trait RevisionWebSocketSink: Send + Sync {
  30. fn next(&self) -> FutureResult<Option<ClientRevisionWSData>, FlowyError>;
  31. }
  32. pub type WSStateReceiver = tokio::sync::broadcast::Receiver<WSConnectState>;
  33. pub trait RevisionWebSocket: Send + Sync + 'static {
  34. fn send(&self, data: ClientRevisionWSData) -> BoxResultFuture<(), FlowyError>;
  35. fn subscribe_state_changed(&self) -> BoxFuture<WSStateReceiver>;
  36. }
  37. pub struct RevisionWebSocketManager {
  38. pub object_name: String,
  39. pub object_id: String,
  40. ws_data_sink: Arc<dyn RevisionWebSocketSink>,
  41. ws_data_stream: Arc<dyn RevisionWSDataStream>,
  42. rev_web_socket: Arc<dyn RevisionWebSocket>,
  43. pub ws_passthrough_tx: Sender<ServerRevisionWSData>,
  44. ws_passthrough_rx: Option<Receiver<ServerRevisionWSData>>,
  45. pub state_passthrough_tx: broadcast::Sender<WSConnectState>,
  46. stop_sync_tx: SinkStopTx,
  47. }
  48. impl std::fmt::Display for RevisionWebSocketManager {
  49. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  50. f.write_fmt(format_args!("{}RevisionWebSocketManager", self.object_name))
  51. }
  52. }
  53. impl RevisionWebSocketManager {
  54. pub fn new(
  55. object_name: &str,
  56. object_id: &str,
  57. rev_web_socket: Arc<dyn RevisionWebSocket>,
  58. ws_data_sink: Arc<dyn RevisionWebSocketSink>,
  59. ws_data_stream: Arc<dyn RevisionWSDataStream>,
  60. ping_duration: Duration,
  61. ) -> Self {
  62. let (ws_passthrough_tx, ws_passthrough_rx) = mpsc::channel(1000);
  63. let (stop_sync_tx, _) = tokio::sync::broadcast::channel(2);
  64. let object_id = object_id.to_string();
  65. let object_name = object_name.to_string();
  66. let (state_passthrough_tx, _) = broadcast::channel(2);
  67. let mut manager = RevisionWebSocketManager {
  68. object_id,
  69. object_name,
  70. ws_data_sink,
  71. ws_data_stream,
  72. rev_web_socket,
  73. ws_passthrough_tx,
  74. ws_passthrough_rx: Some(ws_passthrough_rx),
  75. state_passthrough_tx,
  76. stop_sync_tx,
  77. };
  78. manager.run(ping_duration);
  79. manager
  80. }
  81. fn run(&mut self, ping_duration: Duration) {
  82. let ws_passthrough_rx = self.ws_passthrough_rx.take().expect("Only take once");
  83. let sink = RevisionWSSink::new(
  84. &self.object_id,
  85. &self.object_name,
  86. self.ws_data_sink.clone(),
  87. self.rev_web_socket.clone(),
  88. self.stop_sync_tx.subscribe(),
  89. ping_duration,
  90. );
  91. let stream = RevisionWSStream::new(
  92. &self.object_name,
  93. &self.object_id,
  94. self.ws_data_stream.clone(),
  95. ws_passthrough_rx,
  96. self.stop_sync_tx.subscribe(),
  97. );
  98. tokio::spawn(sink.run());
  99. tokio::spawn(stream.run());
  100. }
  101. pub fn scribe_state(&self) -> broadcast::Receiver<WSConnectState> {
  102. self.state_passthrough_tx.subscribe()
  103. }
  104. pub fn stop(&self) {
  105. if self.stop_sync_tx.send(()).is_ok() {
  106. tracing::trace!("{} stop sync", self.object_id)
  107. }
  108. }
  109. #[tracing::instrument(level = "debug", skip(self, data), err)]
  110. pub async fn receive_ws_data(&self, data: ServerRevisionWSData) -> Result<(), FlowyError> {
  111. self.ws_passthrough_tx.send(data).await.map_err(|e| {
  112. let err_msg = format!("{} passthrough error: {}", self.object_id, e);
  113. FlowyError::internal().context(err_msg)
  114. })?;
  115. Ok(())
  116. }
  117. pub fn connect_state_changed(&self, state: WSConnectState) {
  118. match self.state_passthrough_tx.send(state) {
  119. Ok(_) => {},
  120. Err(e) => tracing::error!("{}", e),
  121. }
  122. }
  123. }
  124. impl std::ops::Drop for RevisionWebSocketManager {
  125. fn drop(&mut self) {
  126. tracing::trace!("{} was dropped", self)
  127. }
  128. }
  129. pub struct RevisionWSStream {
  130. object_name: String,
  131. object_id: String,
  132. consumer: Arc<dyn RevisionWSDataStream>,
  133. ws_msg_rx: Option<mpsc::Receiver<ServerRevisionWSData>>,
  134. stop_rx: Option<SinkStopRx>,
  135. }
  136. impl std::fmt::Display for RevisionWSStream {
  137. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  138. f.write_fmt(format_args!("{}RevisionWSStream", self.object_name))
  139. }
  140. }
  141. impl std::ops::Drop for RevisionWSStream {
  142. fn drop(&mut self) {
  143. tracing::trace!("{} was dropped", self)
  144. }
  145. }
  146. impl RevisionWSStream {
  147. pub fn new(
  148. object_name: &str,
  149. object_id: &str,
  150. consumer: Arc<dyn RevisionWSDataStream>,
  151. ws_msg_rx: mpsc::Receiver<ServerRevisionWSData>,
  152. stop_rx: SinkStopRx,
  153. ) -> Self {
  154. RevisionWSStream {
  155. object_name: object_name.to_string(),
  156. object_id: object_id.to_owned(),
  157. consumer,
  158. ws_msg_rx: Some(ws_msg_rx),
  159. stop_rx: Some(stop_rx),
  160. }
  161. }
  162. pub async fn run(mut self) {
  163. let mut receiver = self.ws_msg_rx.take().expect("Only take once");
  164. let mut stop_rx = self.stop_rx.take().expect("Only take once");
  165. let object_id = self.object_id.clone();
  166. let name = format!("{}", &self);
  167. let stream = stream! {
  168. loop {
  169. tokio::select! {
  170. result = receiver.recv() => {
  171. match result {
  172. Some(msg) => {
  173. yield msg
  174. },
  175. None => {
  176. tracing::debug!("[{}]:{} loop exit", name, object_id);
  177. break;
  178. },
  179. }
  180. },
  181. _ = stop_rx.recv() => {
  182. tracing::debug!("[{}]:{} loop exit", name, object_id);
  183. break
  184. },
  185. };
  186. }
  187. };
  188. stream
  189. .for_each(|msg| async {
  190. match self.handle_message(msg).await {
  191. Ok(_) => {},
  192. Err(e) => tracing::error!("[{}]:{} error: {}", &self, self.object_id, e),
  193. }
  194. })
  195. .await;
  196. }
  197. async fn handle_message(&self, msg: ServerRevisionWSData) -> FlowyResult<()> {
  198. let ServerRevisionWSData { object_id, payload } = msg;
  199. match payload {
  200. WSRevisionPayload::ServerPushRev { revisions } => {
  201. tracing::trace!("[{}]: new push revision: {}", self, object_id);
  202. self.consumer.receive_push_revision(revisions).await?;
  203. },
  204. WSRevisionPayload::ServerPullRev { range } => {
  205. tracing::trace!("[{}]: new pull: {}:{:?}", self, object_id, range);
  206. self.consumer.pull_revisions_in_range(range).await?;
  207. },
  208. WSRevisionPayload::ServerAck { rev_id } => {
  209. tracing::trace!("[{}]: new ack: {}:{}", self, object_id, rev_id);
  210. let _ = self.consumer.receive_ack(rev_id).await;
  211. },
  212. WSRevisionPayload::UserConnect { user } => {
  213. let _ = self.consumer.receive_new_user_connect(user).await;
  214. },
  215. }
  216. Ok(())
  217. }
  218. }
  219. type SinkStopRx = broadcast::Receiver<()>;
  220. type SinkStopTx = broadcast::Sender<()>;
  221. pub struct RevisionWSSink {
  222. object_id: String,
  223. object_name: String,
  224. provider: Arc<dyn RevisionWebSocketSink>,
  225. rev_web_socket: Arc<dyn RevisionWebSocket>,
  226. stop_rx: Option<SinkStopRx>,
  227. ping_duration: Duration,
  228. }
  229. impl RevisionWSSink {
  230. pub fn new(
  231. object_id: &str,
  232. object_name: &str,
  233. provider: Arc<dyn RevisionWebSocketSink>,
  234. rev_web_socket: Arc<dyn RevisionWebSocket>,
  235. stop_rx: SinkStopRx,
  236. ping_duration: Duration,
  237. ) -> Self {
  238. Self {
  239. object_id: object_id.to_owned(),
  240. object_name: object_name.to_owned(),
  241. provider,
  242. rev_web_socket,
  243. stop_rx: Some(stop_rx),
  244. ping_duration,
  245. }
  246. }
  247. pub async fn run(mut self) {
  248. let (tx, mut rx) = mpsc::channel(1);
  249. let mut stop_rx = self.stop_rx.take().expect("Only take once");
  250. let object_id = self.object_id.clone();
  251. tokio::spawn(tick(tx, self.ping_duration));
  252. let name = format!("{}", self);
  253. let stream = stream! {
  254. loop {
  255. tokio::select! {
  256. result = rx.recv() => {
  257. match result {
  258. Some(msg) => yield msg,
  259. None => break,
  260. }
  261. },
  262. _ = stop_rx.recv() => {
  263. tracing::trace!("[{}]:{} loop exit", name, object_id);
  264. break
  265. },
  266. };
  267. }
  268. };
  269. stream
  270. .for_each(|_| async {
  271. match self.send_next_revision().await {
  272. Ok(_) => {},
  273. Err(e) => tracing::error!("[{}] send failed, {:?}", self, e),
  274. }
  275. })
  276. .await;
  277. }
  278. async fn send_next_revision(&self) -> FlowyResult<()> {
  279. match self.provider.next().await? {
  280. None => {
  281. tracing::trace!("[{}]: Finish synchronizing revisions", self);
  282. Ok(())
  283. },
  284. Some(data) => {
  285. tracing::trace!(
  286. "[{}]: send {}:{}-{:?}",
  287. self,
  288. data.object_id,
  289. data.rev_id,
  290. data.ty
  291. );
  292. self.rev_web_socket.send(data).await
  293. },
  294. }
  295. }
  296. }
  297. async fn tick(sender: mpsc::Sender<()>, duration: Duration) {
  298. let mut interval = interval(duration);
  299. while sender.send(()).await.is_ok() {
  300. interval.tick().await;
  301. }
  302. }
  303. impl std::fmt::Display for RevisionWSSink {
  304. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  305. f.write_fmt(format_args!("{}RevisionWSSink", self.object_name))
  306. }
  307. }
  308. impl std::ops::Drop for RevisionWSSink {
  309. fn drop(&mut self) {
  310. tracing::trace!("{} was dropped", self)
  311. }
  312. }
  /// Where [`WSDataProvider`] takes its next outgoing payload from.
  ///
  /// `Debug`, `PartialEq` and `Eq` are derived so the state can be logged and
  /// compared; `Copy` is intentionally not derived so the existing `.clone()`
  /// call sites stay lint-clean.
  #[derive(Clone, Debug, PartialEq, Eq)]
  enum Source {
    /// Payloads explicitly queued through `WSDataProvider::push_data`.
    Custom,
    /// Revisions (or a ping) obtained from the `WSDataProviderDataSource`.
    Revision,
  }
  318. pub trait WSDataProviderDataSource: Send + Sync {
  319. fn next_revision(&self) -> FutureResult<Option<Revision>, FlowyError>;
  320. fn ack_revision(&self, rev_id: i64) -> FutureResult<(), FlowyError>;
  321. fn current_rev_id(&self) -> i64;
  322. }
  323. #[derive(Clone)]
  324. pub struct WSDataProvider {
  325. object_id: String,
  326. rev_ws_data_list: Arc<RwLock<VecDeque<ClientRevisionWSData>>>,
  327. data_source: Arc<dyn WSDataProviderDataSource>,
  328. current_source: Arc<RwLock<Source>>,
  329. }
  330. impl WSDataProvider {
  331. pub fn new(object_id: &str, data_source: Arc<dyn WSDataProviderDataSource>) -> Self {
  332. WSDataProvider {
  333. object_id: object_id.to_owned(),
  334. rev_ws_data_list: Arc::new(RwLock::new(VecDeque::new())),
  335. data_source,
  336. current_source: Arc::new(RwLock::new(Source::Custom)),
  337. }
  338. }
  339. pub async fn push_data(&self, data: ClientRevisionWSData) {
  340. self.rev_ws_data_list.write().await.push_back(data);
  341. }
  342. pub async fn next(&self) -> FlowyResult<Option<ClientRevisionWSData>> {
  343. let source = self.current_source.read().await.clone();
  344. let data = match source {
  345. Source::Custom => match self.rev_ws_data_list.read().await.front() {
  346. None => {
  347. *self.current_source.write().await = Source::Revision;
  348. Ok(None)
  349. },
  350. Some(data) => Ok(Some(data.clone())),
  351. },
  352. Source::Revision => {
  353. if !self.rev_ws_data_list.read().await.is_empty() {
  354. *self.current_source.write().await = Source::Custom;
  355. return Ok(None);
  356. }
  357. match self.data_source.next_revision().await? {
  358. Some(rev) => Ok(Some(ClientRevisionWSData::from_revisions(
  359. &self.object_id,
  360. vec![rev],
  361. ))),
  362. None => Ok(Some(ClientRevisionWSData::ping(
  363. &self.object_id,
  364. self.data_source.current_rev_id(),
  365. ))),
  366. }
  367. },
  368. };
  369. data
  370. }
  371. pub async fn ack_data(&self, rev_id: i64) -> FlowyResult<()> {
  372. let source = self.current_source.read().await.clone();
  373. match source {
  374. Source::Custom => {
  375. let should_pop = match self.rev_ws_data_list.read().await.front() {
  376. None => false,
  377. Some(val) => {
  378. if val.rev_id == rev_id {
  379. true
  380. } else {
  381. tracing::error!(
  382. "The front element's {} is not equal to the {}",
  383. val.rev_id,
  384. rev_id
  385. );
  386. false
  387. }
  388. },
  389. };
  390. if should_pop {
  391. let _ = self.rev_ws_data_list.write().await.pop_front();
  392. }
  393. Ok(())
  394. },
  395. Source::Revision => {
  396. self.data_source.ack_revision(rev_id).await?;
  397. Ok::<(), FlowyError>(())
  398. },
  399. }
  400. }
  401. }
  402. impl ConflictRevisionSink for Arc<WSDataProvider> {
  403. fn send(&self, revisions: Vec<Revision>) -> BoxResultFuture<(), FlowyError> {
  404. let sink = self.clone();
  405. Box::pin(async move {
  406. sink
  407. .push_data(ClientRevisionWSData::from_revisions(
  408. &sink.object_id,
  409. revisions,
  410. ))
  411. .await;
  412. Ok(())
  413. })
  414. }
  415. fn ack(&self, rev_id: i64) -> BoxResultFuture<(), FlowyError> {
  416. let sink = self.clone();
  417. Box::pin(async move { sink.ack_data(rev_id).await })
  418. }
  419. }