rev_manager.rs 13 KB


  1. use crate::rev_queue::{RevCommandSender, RevisionCommand, RevisionQueue};
  2. use crate::{
  3. RevisionPersistence, RevisionSnapshotController, RevisionSnapshotData,
  4. RevisionSnapshotPersistence, WSDataProviderDataSource,
  5. };
  6. use bytes::Bytes;
  7. use flowy_error::{internal_error, FlowyError, FlowyResult};
  8. use lib_infra::future::FutureResult;
  9. use lib_infra::util::md5;
  10. use revision_model::{Revision, RevisionRange};
  11. use std::sync::atomic::AtomicI64;
  12. use std::sync::atomic::Ordering::SeqCst;
  13. use std::sync::Arc;
  14. use tokio::sync::{mpsc, oneshot};
  15. pub trait RevisionCloudService: Send + Sync {
  16. /// Read the object's revision from remote
  17. /// Returns a list of revisions that used to build the object
  18. /// # Arguments
  19. ///
  20. /// * `user_id`: the id of the user
  21. /// * `object_id`: the id of the object
  22. ///
  23. fn fetch_object(&self, user_id: &str, object_id: &str)
  24. -> FutureResult<Vec<Revision>, FlowyError>;
  25. }
  26. pub trait RevisionObjectDeserializer: Send + Sync {
  27. type Output;
  28. /// Deserialize the list of revisions into an concrete object type.
  29. ///
  30. /// # Arguments
  31. ///
  32. /// * `object_id`: the id of the object
  33. /// * `revisions`: a list of revisions that represent the object
  34. ///
  35. fn deserialize_revisions(object_id: &str, revisions: Vec<Revision>) -> FlowyResult<Self::Output>;
  36. fn recover_from_revisions(revisions: Vec<Revision>) -> Option<(Self::Output, i64)>;
  37. }
  38. pub trait RevisionObjectSerializer: Send + Sync {
  39. /// Serialize a list of revisions into one in `Bytes` format
  40. ///
  41. /// * `revisions`: a list of revisions will be serialized to `Bytes`
  42. ///
  43. fn combine_revisions(revisions: Vec<Revision>) -> FlowyResult<Bytes>;
  44. }
  45. /// `RevisionCompress` is used to compress multiple revisions into one revision
  46. ///
  47. pub trait RevisionMergeable: Send + Sync {
  48. fn merge_revisions(
  49. &self,
  50. _user_id: &str,
  51. object_id: &str,
  52. mut revisions: Vec<Revision>,
  53. ) -> FlowyResult<Revision> {
  54. if revisions.is_empty() {
  55. return Err(FlowyError::internal().context("Can't compact the empty revisions"));
  56. }
  57. if revisions.len() == 1 {
  58. return Ok(revisions.pop().unwrap());
  59. }
  60. // Select the last version, making sure version numbers don't overlap
  61. let last_revision = revisions.last().unwrap();
  62. let (base_rev_id, rev_id) = last_revision.pair_rev_id();
  63. let md5 = last_revision.md5.clone();
  64. let bytes = self.combine_revisions(revisions)?;
  65. Ok(Revision::new(object_id, base_rev_id, rev_id, bytes, md5))
  66. }
  67. fn combine_revisions(&self, revisions: Vec<Revision>) -> FlowyResult<Bytes>;
  68. }
  69. pub struct RevisionManager<Connection> {
  70. pub object_id: String,
  71. user_id: String,
  72. rev_id_counter: Arc<RevIdCounter>,
  73. rev_persistence: Arc<RevisionPersistence<Connection>>,
  74. rev_snapshot: Arc<RevisionSnapshotController<Connection>>,
  75. rev_compress: Arc<dyn RevisionMergeable>,
  76. #[cfg(feature = "flowy_unit_test")]
  77. rev_ack_notifier: tokio::sync::broadcast::Sender<i64>,
  78. rev_queue: RevCommandSender,
  79. }
  80. impl<Connection: 'static> RevisionManager<Connection> {
  81. pub fn new<Snapshot, Compress>(
  82. user_id: &str,
  83. object_id: &str,
  84. rev_persistence: RevisionPersistence<Connection>,
  85. rev_compress: Compress,
  86. snapshot_persistence: Snapshot,
  87. ) -> Self
  88. where
  89. Snapshot: 'static + RevisionSnapshotPersistence,
  90. Compress: 'static + RevisionMergeable,
  91. {
  92. let rev_id_counter = Arc::new(RevIdCounter::new(0));
  93. let rev_compress = Arc::new(rev_compress);
  94. let rev_persistence = Arc::new(rev_persistence);
  95. let rev_snapshot = RevisionSnapshotController::new(
  96. user_id,
  97. object_id,
  98. snapshot_persistence,
  99. rev_id_counter.clone(),
  100. rev_persistence.clone(),
  101. rev_compress.clone(),
  102. );
  103. let (rev_queue, receiver) = mpsc::channel(1000);
  104. let queue = RevisionQueue::new(
  105. object_id.to_owned(),
  106. rev_id_counter.clone(),
  107. rev_persistence.clone(),
  108. rev_compress.clone(),
  109. receiver,
  110. );
  111. tokio::spawn(queue.run());
  112. Self {
  113. object_id: object_id.to_string(),
  114. user_id: user_id.to_owned(),
  115. rev_id_counter,
  116. rev_persistence,
  117. rev_snapshot: Arc::new(rev_snapshot),
  118. rev_compress,
  119. #[cfg(feature = "flowy_unit_test")]
  120. rev_ack_notifier: tokio::sync::broadcast::channel(1).0,
  121. rev_queue,
  122. }
  123. }
  124. #[tracing::instrument(name = "revision_manager_initialize", level = "trace", skip_all, fields(deserializer, object_id, deserialize_revisions) err)]
  125. pub async fn initialize<De>(
  126. &mut self,
  127. _cloud: Option<Arc<dyn RevisionCloudService>>,
  128. ) -> FlowyResult<De::Output>
  129. where
  130. De: RevisionObjectDeserializer,
  131. {
  132. let revision_records = self.rev_persistence.load_all_records(&self.object_id)?;
  133. tracing::Span::current().record("object_id", self.object_id.as_str());
  134. tracing::Span::current().record("deserializer", std::any::type_name::<De>());
  135. let revisions: Vec<Revision> = revision_records
  136. .iter()
  137. .map(|record| record.revision.clone())
  138. .collect();
  139. tracing::Span::current().record("deserialize_revisions", revisions.len());
  140. let last_rev_id = revisions
  141. .last()
  142. .as_ref()
  143. .map(|revision| revision.rev_id)
  144. .unwrap_or(0);
  145. match De::deserialize_revisions(&self.object_id, revisions.clone()) {
  146. Ok(object) => {
  147. self
  148. .rev_persistence
  149. .sync_revision_records(&revision_records)
  150. .await?;
  151. self.rev_id_counter.set(last_rev_id);
  152. Ok(object)
  153. },
  154. Err(e) => match self.rev_snapshot.restore_from_snapshot::<De>(last_rev_id) {
  155. None => {
  156. tracing::info!("[Restore] iterate restore from each revision");
  157. let (output, recover_rev_id) = De::recover_from_revisions(revisions).ok_or(e)?;
  158. tracing::info!(
  159. "[Restore] last_rev_id:{}, recover_rev_id: {}",
  160. last_rev_id,
  161. recover_rev_id
  162. );
  163. self.rev_id_counter.set(recover_rev_id);
  164. // delete the revisions whose rev_id is greater than recover_rev_id
  165. if recover_rev_id < last_rev_id {
  166. let range = RevisionRange {
  167. start: recover_rev_id + 1,
  168. end: last_rev_id,
  169. };
  170. tracing::info!("[Restore] delete revisions in range: {}", range);
  171. let _ = self.rev_persistence.delete_revisions_from_range(range);
  172. }
  173. Ok(output)
  174. },
  175. Some((object, snapshot_rev)) => {
  176. let snapshot_rev_id = snapshot_rev.rev_id;
  177. let _ = self.rev_persistence.reset(vec![snapshot_rev]).await;
  178. // revision_records.retain(|record| record.revision.rev_id <= snapshot_rev_id);
  179. // let _ = self.rev_persistence.sync_revision_records(&revision_records).await?;
  180. self.rev_id_counter.set(snapshot_rev_id);
  181. Ok(object)
  182. },
  183. },
  184. }
  185. }
  186. pub async fn close(&self) {
  187. let _ = self
  188. .rev_persistence
  189. .merge_lagging_revisions(&self.rev_compress)
  190. .await;
  191. }
  192. pub async fn generate_snapshot(&self) {
  193. self.rev_snapshot.generate_snapshot().await;
  194. }
  195. pub async fn read_snapshot(
  196. &self,
  197. rev_id: Option<i64>,
  198. ) -> FlowyResult<Option<RevisionSnapshotData>> {
  199. match rev_id {
  200. None => self.rev_snapshot.read_last_snapshot(),
  201. Some(rev_id) => self.rev_snapshot.read_snapshot(rev_id),
  202. }
  203. }
  204. pub async fn load_revisions(&self) -> FlowyResult<Vec<Revision>> {
  205. let revisions = RevisionLoader {
  206. object_id: self.object_id.clone(),
  207. user_id: self.user_id.clone(),
  208. cloud: None,
  209. rev_persistence: self.rev_persistence.clone(),
  210. }
  211. .load_revisions()
  212. .await?;
  213. Ok(revisions)
  214. }
  215. #[tracing::instrument(level = "trace", skip(self, revisions), err)]
  216. pub async fn reset_object(&self, revisions: Vec<Revision>) -> FlowyResult<()> {
  217. let rev_id = pair_rev_id_from_revisions(&revisions).1;
  218. self.rev_persistence.reset(revisions).await?;
  219. self.rev_id_counter.set(rev_id);
  220. Ok(())
  221. }
  222. #[tracing::instrument(level = "debug", skip(self, revision), err)]
  223. pub async fn add_remote_revision(&self, revision: &Revision) -> Result<(), FlowyError> {
  224. if revision.bytes.is_empty() {
  225. return Err(FlowyError::internal().context("Remote revisions is empty"));
  226. }
  227. self.rev_persistence.add_ack_revision(revision).await?;
  228. self.rev_id_counter.set(revision.rev_id);
  229. Ok(())
  230. }
  231. /// Adds the revision that generated by user editing
  232. // #[tracing::instrument(level = "trace", skip_all, err)]
  233. pub async fn add_local_revision(
  234. &self,
  235. data: Bytes,
  236. object_md5: String,
  237. ) -> Result<i64, FlowyError> {
  238. if data.is_empty() {
  239. return Err(FlowyError::internal().context("The data of the revisions is empty"));
  240. }
  241. self.rev_snapshot.generate_snapshot_if_need();
  242. let (ret, rx) = oneshot::channel();
  243. self
  244. .rev_queue
  245. .send(RevisionCommand::RevisionData {
  246. data,
  247. object_md5,
  248. ret,
  249. })
  250. .await
  251. .map_err(internal_error)?;
  252. rx.await.map_err(internal_error)?
  253. }
  254. #[tracing::instrument(level = "debug", skip(self), err)]
  255. pub async fn ack_revision(&self, rev_id: i64) -> Result<(), FlowyError> {
  256. if self.rev_persistence.ack_revision(rev_id).await.is_ok() {
  257. #[cfg(feature = "flowy_unit_test")]
  258. let _ = self.rev_ack_notifier.send(rev_id);
  259. }
  260. Ok(())
  261. }
  262. /// Returns the current revision id
  263. pub fn rev_id(&self) -> i64 {
  264. self.rev_id_counter.value()
  265. }
  266. pub async fn next_sync_rev_id(&self) -> Option<i64> {
  267. self.rev_persistence.next_sync_rev_id().await
  268. }
  269. pub fn next_rev_id_pair(&self) -> (i64, i64) {
  270. let cur = self.rev_id_counter.value();
  271. let next = self.rev_id_counter.next_id();
  272. (cur, next)
  273. }
  274. pub fn number_of_sync_revisions(&self) -> usize {
  275. self.rev_persistence.number_of_sync_records()
  276. }
  277. pub fn number_of_revisions_in_disk(&self) -> usize {
  278. self.rev_persistence.number_of_records_in_disk()
  279. }
  280. pub async fn get_revisions_in_range(
  281. &self,
  282. range: RevisionRange,
  283. ) -> Result<Vec<Revision>, FlowyError> {
  284. let revisions = self.rev_persistence.revisions_in_range(&range).await?;
  285. Ok(revisions)
  286. }
  287. pub async fn next_sync_revision(&self) -> FlowyResult<Option<Revision>> {
  288. self.rev_persistence.next_sync_revision().await
  289. }
  290. pub async fn get_revision(&self, rev_id: i64) -> Option<Revision> {
  291. self
  292. .rev_persistence
  293. .get(rev_id)
  294. .await
  295. .map(|record| record.revision)
  296. }
  297. }
  298. impl<Connection: 'static> WSDataProviderDataSource for Arc<RevisionManager<Connection>> {
  299. fn next_revision(&self) -> FutureResult<Option<Revision>, FlowyError> {
  300. let rev_manager = self.clone();
  301. FutureResult::new(async move { rev_manager.next_sync_revision().await })
  302. }
  303. fn ack_revision(&self, rev_id: i64) -> FutureResult<(), FlowyError> {
  304. let rev_manager = self.clone();
  305. FutureResult::new(async move { (*rev_manager).ack_revision(rev_id).await })
  306. }
  307. fn current_rev_id(&self) -> i64 {
  308. self.rev_id()
  309. }
  310. }
  /// Accessors that are only compiled for unit tests.
  #[cfg(feature = "flowy_unit_test")]
  impl<Connection: 'static> RevisionManager<Connection> {
    /// Returns a handle to the revision persistence used by the manager.
    pub async fn revision_cache(&self) -> Arc<RevisionPersistence<Connection>> {
      Arc::clone(&self.rev_persistence)
    }

    /// Subscribes to the ids of the acknowledged revisions.
    pub fn ack_notify(&self) -> tokio::sync::broadcast::Receiver<i64> {
      let notifier = &self.rev_ack_notifier;
      notifier.subscribe()
    }

    /// Loads every revision record of the object from the persistence.
    pub fn get_all_revision_records(
      &self,
    ) -> FlowyResult<Vec<flowy_revision_persistence::SyncRecord>> {
      let object_id = &self.object_id;
      self.rev_persistence.load_all_records(object_id)
    }
  }
  325. pub struct RevisionLoader<Connection> {
  326. pub object_id: String,
  327. pub user_id: String,
  328. pub cloud: Option<Arc<dyn RevisionCloudService>>,
  329. pub rev_persistence: Arc<RevisionPersistence<Connection>>,
  330. }
  331. impl<Connection: 'static> RevisionLoader<Connection> {
  332. pub async fn load_revisions(&self) -> Result<Vec<Revision>, FlowyError> {
  333. let records = self.rev_persistence.load_all_records(&self.object_id)?;
  334. let revisions = records
  335. .into_iter()
  336. .map(|record| record.revision)
  337. .collect::<_>();
  338. Ok(revisions)
  339. }
  340. }
  /// The md5 of the revision object after the revision has been applied.
  /// For a document, for example, `RevisionMD5` is the md5 of the document
  /// content.
  #[derive(Debug, Clone)]
  pub struct RevisionMD5(String);
  346. impl RevisionMD5 {
  347. pub fn from_bytes<T: AsRef<[u8]>>(bytes: T) -> Result<Self, FlowyError> {
  348. Ok(RevisionMD5(md5(bytes)))
  349. }
  350. pub fn into_inner(self) -> String {
  351. self.0
  352. }
  353. pub fn is_equal(&self, s: &str) -> bool {
  354. self.0 == s
  355. }
  356. }
  357. impl std::convert::From<RevisionMD5> for String {
  358. fn from(md5: RevisionMD5) -> Self {
  359. md5.0
  360. }
  361. }
  362. impl std::convert::From<&str> for RevisionMD5 {
  363. fn from(s: &str) -> Self {
  364. Self(s.to_owned())
  365. }
  366. }
  367. impl std::convert::From<String> for RevisionMD5 {
  368. fn from(s: String) -> Self {
  369. Self(s)
  370. }
  371. }
  372. impl std::ops::Deref for RevisionMD5 {
  373. type Target = String;
  374. fn deref(&self) -> &Self::Target {
  375. &self.0
  376. }
  377. }
  378. impl PartialEq<Self> for RevisionMD5 {
  379. fn eq(&self, other: &Self) -> bool {
  380. self.0 == other.0
  381. }
  382. }
  383. impl std::cmp::Eq for RevisionMD5 {}
  384. fn pair_rev_id_from_revisions(revisions: &[Revision]) -> (i64, i64) {
  385. let mut rev_id = 0;
  386. revisions.iter().for_each(|revision| {
  387. if rev_id < revision.rev_id {
  388. rev_id = revision.rev_id;
  389. }
  390. });
  391. if rev_id > 0 {
  392. (rev_id - 1, rev_id)
  393. } else {
  394. (0, rev_id)
  395. }
  396. }
  /// A revision id counter that can be shared between threads.
  #[derive(Debug)]
  pub struct RevIdCounter(pub AtomicI64);

  impl RevIdCounter {
    /// Creates a counter whose current value is `n`.
    pub fn new(n: i64) -> Self {
      Self(AtomicI64::new(n))
    }

    /// Increments the counter by one and returns the incremented value.
    ///
    /// The returned id is derived from the atomic increment itself rather
    /// than from a second read of the counter, so concurrent callers never
    /// receive the same id. Like `fetch_add`, the value wraps around on
    /// overflow.
    pub fn next_id(&self) -> i64 {
      self.0.fetch_add(1, SeqCst).wrapping_add(1)
    }

    /// Returns the current value of the counter.
    pub fn value(&self) -> i64 {
      self.0.load(SeqCst)
    }

    /// Sets the counter to `n`, whatever its current value is.
    pub fn set(&self, n: i64) {
      self.0.store(n, SeqCst);
    }
  }
  413. }