ws_manager.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. use crate::ConflictRevisionSink;
  2. use async_stream::stream;
  3. use bytes::Bytes;
  4. use flowy_collaboration::entities::{
  5. revision::{RevId, Revision, RevisionRange},
  6. ws_data::{ClientRevisionWSData, NewDocumentUser, ServerRevisionWSData, ServerRevisionWSDataType},
  7. };
  8. use flowy_error::{FlowyError, FlowyResult};
  9. use futures_util::{future::BoxFuture, stream::StreamExt};
  10. use lib_infra::future::{BoxResultFuture, FutureResult};
  11. use lib_ws::WSConnectState;
  12. use std::{collections::VecDeque, convert::TryFrom, fmt::Formatter, sync::Arc};
  13. use tokio::{
  14. sync::{
  15. broadcast, mpsc,
  16. mpsc::{Receiver, Sender},
  17. RwLock,
  18. },
  19. time::{interval, Duration},
  20. };
  21. // The consumer consumes the messages pushed by the web socket.
  /// Receiver of decoded server messages. `RevisionWSStream` calls exactly one of these
  /// methods per message, chosen by the message's `ServerRevisionWSDataType`.
  22. pub trait RevisionWSDataStream: Send + Sync {
  /// Called for `ServerPushRev`; `bytes` is the raw message payload, passed through undecoded.
  23. fn receive_push_revision(&self, bytes: Bytes) -> BoxResultFuture<(), FlowyError>;
  /// Called for `ServerAck`; `id` is the acknowledged `RevId` value rendered with `to_string()`.
  24. fn receive_ack(&self, id: String, ty: ServerRevisionWSDataType) -> BoxResultFuture<(), FlowyError>;
  /// Called for `UserConnect` with the payload decoded as a `NewDocumentUser`.
  25. fn receive_new_user_connect(&self, new_user: NewDocumentUser) -> BoxResultFuture<(), FlowyError>;
  /// Called for `ServerPullRev` with the payload decoded as the requested `RevisionRange`.
  26. fn pull_revisions_in_range(&self, range: RevisionRange) -> BoxResultFuture<(), FlowyError>;
  27. }
  28. // The sink provides the data that will be sent through the web socket to the
  29. // backend.
  30. pub trait RevisionWebSocketSink: Send + Sync {
  /// Polled by `RevisionWSSink` once per tick. `Ok(None)` means there is nothing to send this
  /// round (only traced). `Ok(Some(data))` is handed to `RevisionWebSocket::send`.
  31. fn next(&self) -> FutureResult<Option<ClientRevisionWSData>, FlowyError>;
  32. }
  /// Broadcast receiver of websocket connection-state changes.
  33. pub type WSStateReceiver = tokio::sync::broadcast::Receiver<WSConnectState>;
  /// Transport used to talk to the server.
  34. pub trait RevisionWebSocket: Send + Sync + 'static {
  /// Sends one message. `RevisionWSSink` calls this for every item its provider yields and
  /// logs (but does not retry) errors.
  35. fn send(&self, data: ClientRevisionWSData) -> BoxResultFuture<(), FlowyError>;
  /// Returns a receiver of connection-state changes. Not called within this file.
  /// NOTE(review): presumably yields a fresh subscription per call; confirm with implementors.
  36. fn subscribe_state_changed(&self) -> BoxFuture<WSStateReceiver>;
  37. }
  /// Per-object websocket coordinator. `new` spawns a `RevisionWSSink` task (outgoing data)
  /// and a `RevisionWSStream` task (incoming data); this struct keeps the channels that feed
  /// and stop them.
  38. pub struct RevisionWebSocketManager {
  /// Used as the prefix of this type's `Display` output.
  39. pub object_name: String,
  /// Id of the synchronized object; included in log and error messages.
  40. pub object_id: String,
  /// Source of outgoing data, handed to the spawned sink task.
  41. ws_data_sink: Arc<dyn RevisionWebSocketSink>,
  /// Consumer of incoming data, handed to the spawned stream task.
  42. ws_data_stream: Arc<dyn RevisionWSDataStream>,
  /// Transport the sink task sends through.
  43. rev_web_socket: Arc<dyn RevisionWebSocket>,
  /// Feeds server messages to the stream task; `receive_ws_data` writes here
  /// (bounded to 1000 messages in `new`).
  44. pub ws_passthrough_tx: Sender<ServerRevisionWSData>,
  /// Receiving half of the channel above; `Some` only until `run` moves it into the stream task.
  45. ws_passthrough_rx: Option<Receiver<ServerRevisionWSData>>,
  /// Re-broadcasts states passed to `connect_state_changed`; subscribe via `scribe_state`.
  46. pub state_passthrough_tx: broadcast::Sender<WSConnectState>,
  /// Stop signal. Both tasks hold a subscription and leave their loops when it fires or when
  /// this sender is dropped (their `select!` arm matches any `recv()` result).
  47. stop_sync_tx: SinkStopTx,
  48. }
  49. impl std::fmt::Display for RevisionWebSocketManager {
  50. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  51. f.write_fmt(format_args!("{}RevisionWebSocketManager", self.object_name))
  52. }
  53. }
  54. impl RevisionWebSocketManager {
  55. pub fn new(
  56. object_name: &str,
  57. object_id: &str,
  58. rev_web_socket: Arc<dyn RevisionWebSocket>,
  59. ws_data_sink: Arc<dyn RevisionWebSocketSink>,
  60. ws_data_stream: Arc<dyn RevisionWSDataStream>,
  61. ping_duration: Duration,
  62. ) -> Self {
  63. let (ws_passthrough_tx, ws_passthrough_rx) = mpsc::channel(1000);
  64. let (stop_sync_tx, _) = tokio::sync::broadcast::channel(2);
  65. let object_id = object_id.to_string();
  66. let object_name = object_name.to_string();
  67. let (state_passthrough_tx, _) = broadcast::channel(2);
  68. let mut manager = RevisionWebSocketManager {
  69. object_id,
  70. object_name,
  71. ws_data_sink,
  72. ws_data_stream,
  73. rev_web_socket,
  74. ws_passthrough_tx,
  75. ws_passthrough_rx: Some(ws_passthrough_rx),
  76. state_passthrough_tx,
  77. stop_sync_tx,
  78. };
  79. manager.run(ping_duration);
  80. manager
  81. }
  82. fn run(&mut self, ping_duration: Duration) {
  83. let ws_passthrough_rx = self.ws_passthrough_rx.take().expect("Only take once");
  84. let sink = RevisionWSSink::new(
  85. &self.object_id,
  86. &self.object_name,
  87. self.ws_data_sink.clone(),
  88. self.rev_web_socket.clone(),
  89. self.stop_sync_tx.subscribe(),
  90. ping_duration,
  91. );
  92. let stream = RevisionWSStream::new(
  93. &self.object_name,
  94. &self.object_id,
  95. self.ws_data_stream.clone(),
  96. ws_passthrough_rx,
  97. self.stop_sync_tx.subscribe(),
  98. );
  99. tokio::spawn(sink.run());
  100. tokio::spawn(stream.run());
  101. }
  102. pub fn scribe_state(&self) -> broadcast::Receiver<WSConnectState> {
  103. self.state_passthrough_tx.subscribe()
  104. }
  105. pub fn stop(&self) {
  106. if self.stop_sync_tx.send(()).is_ok() {
  107. tracing::trace!("{} stop sync", self.object_id)
  108. }
  109. }
  110. #[tracing::instrument(level = "debug", skip(self, data), err)]
  111. pub async fn receive_ws_data(&self, data: ServerRevisionWSData) -> Result<(), FlowyError> {
  112. let _ = self.ws_passthrough_tx.send(data).await.map_err(|e| {
  113. let err_msg = format!("{} passthrough error: {}", self.object_id, e);
  114. FlowyError::internal().context(err_msg)
  115. })?;
  116. Ok(())
  117. }
  118. pub fn connect_state_changed(&self, state: WSConnectState) {
  119. match self.state_passthrough_tx.send(state) {
  120. Ok(_) => {}
  121. Err(e) => tracing::error!("{}", e),
  122. }
  123. }
  124. }
  125. impl std::ops::Drop for RevisionWebSocketManager {
  126. fn drop(&mut self) {
  127. tracing::trace!("{} was dropped", self)
  128. }
  129. }
  /// Incoming half: drains server messages from the manager's passthrough channel and
  /// dispatches each one to a `RevisionWSDataStream` consumer.
  130. pub struct RevisionWSStream {
  /// Used as the prefix of this type's `Display` output.
  131. object_name: String,
  /// Id of the synchronized object; included in log messages.
  132. object_id: String,
  /// Receives the decoded messages (see `handle_message`).
  133. consumer: Arc<dyn RevisionWSDataStream>,
  /// Message source; `Some` until `run` takes it.
  134. ws_msg_rx: Option<mpsc::Receiver<ServerRevisionWSData>>,
  /// Stop signal; `Some` until `run` takes it.
  135. stop_rx: Option<SinkStopRx>,
  136. }
  137. impl std::fmt::Display for RevisionWSStream {
  138. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  139. f.write_fmt(format_args!("{}RevisionWSStream", self.object_name))
  140. }
  141. }
  142. impl std::ops::Drop for RevisionWSStream {
  143. fn drop(&mut self) {
  144. tracing::trace!("{} was dropped", self)
  145. }
  146. }
  147. impl RevisionWSStream {
  148. pub fn new(
  149. object_name: &str,
  150. object_id: &str,
  151. consumer: Arc<dyn RevisionWSDataStream>,
  152. ws_msg_rx: mpsc::Receiver<ServerRevisionWSData>,
  153. stop_rx: SinkStopRx,
  154. ) -> Self {
  155. RevisionWSStream {
  156. object_name: object_name.to_string(),
  157. object_id: object_id.to_owned(),
  158. consumer,
  159. ws_msg_rx: Some(ws_msg_rx),
  160. stop_rx: Some(stop_rx),
  161. }
  162. }
  163. pub async fn run(mut self) {
  164. let mut receiver = self.ws_msg_rx.take().expect("Only take once");
  165. let mut stop_rx = self.stop_rx.take().expect("Only take once");
  166. let object_id = self.object_id.clone();
  167. let name = format!("{}", &self);
  168. let stream = stream! {
  169. loop {
  170. tokio::select! {
  171. result = receiver.recv() => {
  172. match result {
  173. Some(msg) => {
  174. yield msg
  175. },
  176. None => {
  177. tracing::debug!("[{}]:{} loop exit", name, object_id);
  178. break;
  179. },
  180. }
  181. },
  182. _ = stop_rx.recv() => {
  183. tracing::debug!("[{}]:{} loop exit", name, object_id);
  184. break
  185. },
  186. };
  187. }
  188. };
  189. stream
  190. .for_each(|msg| async {
  191. match self.handle_message(msg).await {
  192. Ok(_) => {}
  193. Err(e) => tracing::error!("[{}]:{} error: {}", &self, self.object_id, e),
  194. }
  195. })
  196. .await;
  197. }
  198. async fn handle_message(&self, msg: ServerRevisionWSData) -> FlowyResult<()> {
  199. let ServerRevisionWSData { object_id, ty, data } = msg;
  200. let bytes = Bytes::from(data);
  201. match ty {
  202. ServerRevisionWSDataType::ServerPushRev => {
  203. tracing::trace!("[{}]: new push revision: {}:{:?}", self, object_id, ty);
  204. let _ = self.consumer.receive_push_revision(bytes).await?;
  205. }
  206. ServerRevisionWSDataType::ServerPullRev => {
  207. let range = RevisionRange::try_from(bytes)?;
  208. tracing::trace!("[{}]: new pull: {}:{}-{:?}", self, object_id, range, ty);
  209. let _ = self.consumer.pull_revisions_in_range(range).await?;
  210. }
  211. ServerRevisionWSDataType::ServerAck => {
  212. let rev_id = RevId::try_from(bytes).unwrap().value;
  213. tracing::trace!("[{}]: new ack: {}:{}-{:?}", self, object_id, rev_id, ty);
  214. let _ = self.consumer.receive_ack(rev_id.to_string(), ty).await;
  215. }
  216. ServerRevisionWSDataType::UserConnect => {
  217. let new_user = NewDocumentUser::try_from(bytes)?;
  218. let _ = self.consumer.receive_new_user_connect(new_user).await;
  219. }
  220. }
  221. Ok(())
  222. }
  223. }
  /// Receiving end of the stop signal. Despite the `Sink` prefix, `RevisionWSStream` uses it
  /// too; each task gets its own receiver from `SinkStopTx::subscribe()`.
  224. type SinkStopRx = broadcast::Receiver<()>;
  /// Sending end of the stop signal, held by `RevisionWebSocketManager` (its `stop` sends on it).
  225. type SinkStopTx = broadcast::Sender<()>;
  /// Outgoing half: on every tick it asks `provider` for the next item and sends it through
  /// `rev_web_socket`.
  226. pub struct RevisionWSSink {
  /// Id of the synchronized object; included in log messages.
  227. object_id: String,
  /// Used as the prefix of this type's `Display` output.
  228. object_name: String,
  /// Supplies the data to send; `None` means nothing to send this tick.
  229. provider: Arc<dyn RevisionWebSocketSink>,
  /// Transport the data is sent through.
  230. rev_web_socket: Arc<dyn RevisionWebSocket>,
  /// Stop signal; `Some` until `run` takes it.
  231. stop_rx: Option<SinkStopRx>,
  /// Interval between ticks.
  232. ping_duration: Duration,
  233. }
  234. impl RevisionWSSink {
  235. pub fn new(
  236. object_id: &str,
  237. object_name: &str,
  238. provider: Arc<dyn RevisionWebSocketSink>,
  239. rev_web_socket: Arc<dyn RevisionWebSocket>,
  240. stop_rx: SinkStopRx,
  241. ping_duration: Duration,
  242. ) -> Self {
  243. Self {
  244. object_id: object_id.to_owned(),
  245. object_name: object_name.to_owned(),
  246. provider,
  247. rev_web_socket,
  248. stop_rx: Some(stop_rx),
  249. ping_duration,
  250. }
  251. }
  252. pub async fn run(mut self) {
  253. let (tx, mut rx) = mpsc::channel(1);
  254. let mut stop_rx = self.stop_rx.take().expect("Only take once");
  255. let object_id = self.object_id.clone();
  256. tokio::spawn(tick(tx, self.ping_duration));
  257. let name = format!("{}", self);
  258. let stream = stream! {
  259. loop {
  260. tokio::select! {
  261. result = rx.recv() => {
  262. match result {
  263. Some(msg) => yield msg,
  264. None => break,
  265. }
  266. },
  267. _ = stop_rx.recv() => {
  268. tracing::trace!("[{}]:{} loop exit", name, object_id);
  269. break
  270. },
  271. };
  272. }
  273. };
  274. stream
  275. .for_each(|_| async {
  276. match self.send_next_revision().await {
  277. Ok(_) => {}
  278. Err(e) => tracing::error!("[{}] send failed, {:?}", self, e),
  279. }
  280. })
  281. .await;
  282. }
  283. async fn send_next_revision(&self) -> FlowyResult<()> {
  284. match self.provider.next().await? {
  285. None => {
  286. tracing::trace!("[{}]: Finish synchronizing revisions", self);
  287. Ok(())
  288. }
  289. Some(data) => {
  290. tracing::trace!("[{}]: send {}:{}-{:?}", self, data.object_id, data.id(), data.ty);
  291. self.rev_web_socket.send(data).await
  292. }
  293. }
  294. }
  295. }
  296. async fn tick(sender: mpsc::Sender<()>, duration: Duration) {
  297. let mut interval = interval(duration);
  298. while sender.send(()).await.is_ok() {
  299. interval.tick().await;
  300. }
  301. }
  302. impl std::fmt::Display for RevisionWSSink {
  303. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  304. f.write_fmt(format_args!("{}RevisionWSSink", self.object_name))
  305. }
  306. }
  307. impl std::ops::Drop for RevisionWSSink {
  308. fn drop(&mut self) {
  309. tracing::trace!("{} was dropped", self)
  310. }
  311. }
  /// Which queue `WSDataProvider` is currently serving from.
  312. #[derive(Clone)]
  313. enum Source {
  /// Items queued through `WSDataProvider::push_data`.
  314. Custom,
  /// Revisions from `WSDataProviderDataSource::next_revision`, or a ping when there are none.
  315. Revision,
  316. }
  /// Revision store that backs `WSDataProvider` while it is serving `Source::Revision`.
  317. pub trait WSDataProviderDataSource: Send + Sync {
  /// Next revision to send. `None` makes the provider send a ping carrying `current_rev_id()`.
  318. fn next_revision(&self) -> FutureResult<Option<Revision>, FlowyError>;
  /// Called with the parsed revision id when an ack arrives while the provider is in revision mode.
  319. fn ack_revision(&self, rev_id: i64) -> FutureResult<(), FlowyError>;
  /// Current revision id, reported in ping messages.
  320. fn current_rev_id(&self) -> i64;
  321. }
  /// Chooses what to send next. Explicitly queued `ClientRevisionWSData` items
  /// (`Source::Custom`) take priority over revisions pulled from `data_source`
  /// (`Source::Revision`). Clones share the same queue and state through the `Arc`s.
  322. #[derive(Clone)]
  323. pub struct WSDataProvider {
  /// Id of the synchronized object; used when building outgoing data.
  324. object_id: String,
  /// FIFO of items queued via `push_data`. The head is returned by `next` again and again
  /// until `ack_data` pops it.
  325. rev_ws_data_list: Arc<RwLock<VecDeque<ClientRevisionWSData>>>,
  /// Revision store consulted in `Source::Revision` mode.
  326. data_source: Arc<dyn WSDataProviderDataSource>,
  /// Which queue `next` and `ack_data` currently operate on.
  327. current_source: Arc<RwLock<Source>>,
  328. }
  329. impl WSDataProvider {
  330. pub fn new(object_id: &str, data_source: Arc<dyn WSDataProviderDataSource>) -> Self {
  331. WSDataProvider {
  332. object_id: object_id.to_owned(),
  333. rev_ws_data_list: Arc::new(RwLock::new(VecDeque::new())),
  334. data_source,
  335. current_source: Arc::new(RwLock::new(Source::Custom)),
  336. }
  337. }
  338. pub async fn push_data(&self, data: ClientRevisionWSData) {
  339. self.rev_ws_data_list.write().await.push_back(data);
  340. }
  341. pub async fn next(&self) -> FlowyResult<Option<ClientRevisionWSData>> {
  342. let source = self.current_source.read().await.clone();
  343. let data = match source {
  344. Source::Custom => match self.rev_ws_data_list.read().await.front() {
  345. None => {
  346. *self.current_source.write().await = Source::Revision;
  347. Ok(None)
  348. }
  349. Some(data) => Ok(Some(data.clone())),
  350. },
  351. Source::Revision => {
  352. if !self.rev_ws_data_list.read().await.is_empty() {
  353. *self.current_source.write().await = Source::Custom;
  354. return Ok(None);
  355. }
  356. match self.data_source.next_revision().await? {
  357. Some(rev) => Ok(Some(ClientRevisionWSData::from_revisions(&self.object_id, vec![rev]))),
  358. None => Ok(Some(ClientRevisionWSData::ping(
  359. &self.object_id,
  360. self.data_source.current_rev_id(),
  361. ))),
  362. }
  363. }
  364. };
  365. data
  366. }
  367. pub async fn ack_data(&self, id: String, _ty: ServerRevisionWSDataType) -> FlowyResult<()> {
  368. let source = self.current_source.read().await.clone();
  369. match source {
  370. Source::Custom => {
  371. let should_pop = match self.rev_ws_data_list.read().await.front() {
  372. None => false,
  373. Some(val) => {
  374. let expected_id = val.id();
  375. if expected_id == id {
  376. true
  377. } else {
  378. tracing::error!("The front element's {} is not equal to the {}", expected_id, id);
  379. false
  380. }
  381. }
  382. };
  383. if should_pop {
  384. let _ = self.rev_ws_data_list.write().await.pop_front();
  385. }
  386. Ok(())
  387. }
  388. Source::Revision => {
  389. let rev_id = id.parse::<i64>().map_err(|e| {
  390. FlowyError::internal().context(format!("Parse {} rev_id from {} failed. {}", self.object_id, id, e))
  391. })?;
  392. let _ = self.data_source.ack_revision(rev_id).await?;
  393. Ok::<(), FlowyError>(())
  394. }
  395. }
  396. }
  397. }
  398. impl ConflictRevisionSink for Arc<WSDataProvider> {
  399. fn send(&self, revisions: Vec<Revision>) -> BoxResultFuture<(), FlowyError> {
  400. let sink = self.clone();
  401. Box::pin(async move {
  402. sink.push_data(ClientRevisionWSData::from_revisions(&sink.object_id, revisions))
  403. .await;
  404. Ok(())
  405. })
  406. }
  407. fn ack(&self, rev_id: String, ty: ServerRevisionWSDataType) -> BoxResultFuture<(), FlowyError> {
  408. let sink = self.clone();
  409. Box::pin(async move { sink.ack_data(rev_id, ty).await })
  410. }
  411. }