Browse Source

reset document data using transaction

appflowy 3 năm trước cách đây
mục cha
commit
e4e40ebe20

+ 56 - 36
backend/src/services/document/persistence.rs

@@ -1,6 +1,6 @@
 use crate::{
     context::FlowyPersistence,
-    services::kv::{KVStore, KeyValue},
+    services::kv::{KVStore, KVTransaction, KeyValue},
     util::serde_ext::parse_from_bytes,
 };
 use anyhow::Context;
@@ -45,11 +45,20 @@ pub(crate) async fn read_document(
 #[tracing::instrument(level = "debug", skip(kv_store, params), fields(delta), err)]
 pub async fn reset_document(
     kv_store: &Arc<DocumentKVPersistence>,
-    params: ResetDocumentParams,
+    mut params: ResetDocumentParams,
 ) -> Result<(), ServerError> {
-    // TODO: Reset document requires atomic operation
-    // let _ = kv_store.batch_delete_revisions(&doc_id.to_string(), None).await?;
-    todo!()
+    let revisions = params.take_revisions().take_items();
+    let doc_id = params.take_doc_id();
+    kv_store
+        .transaction(|mut transaction| {
+            Box::pin(async move {
+                let _ = transaction.batch_delete_key_start_with(&doc_id).await?;
+                let items = revisions_to_key_value_items(revisions.into());
+                let _ = transaction.batch_set(items).await?;
+                Ok(())
+            })
+        })
+        .await
 }
 
 #[tracing::instrument(level = "debug", skip(kv_store), err)]
@@ -59,11 +68,11 @@ pub(crate) async fn delete_document(kv_store: &Arc<DocumentKVPersistence>, doc_i
 }
 
 pub struct DocumentKVPersistence {
-    inner: Arc<dyn KVStore>,
+    inner: Arc<KVStore>,
 }
 
 impl std::ops::Deref for DocumentKVPersistence {
-    type Target = Arc<dyn KVStore>;
+    type Target = Arc<KVStore>;
 
     fn deref(&self) -> &Self::Target { &self.inner }
 }
@@ -73,34 +82,21 @@ impl std::ops::DerefMut for DocumentKVPersistence {
 }
 
 impl DocumentKVPersistence {
-    pub(crate) fn new(kv_store: Arc<dyn KVStore>) -> Self { DocumentKVPersistence { inner: kv_store } }
+    pub(crate) fn new(kv_store: Arc<KVStore>) -> Self { DocumentKVPersistence { inner: kv_store } }
 
     pub(crate) async fn batch_set_revision(&self, revisions: Vec<Revision>) -> Result<(), ServerError> {
-        let kv_store = self.inner.clone();
-        let items = revisions
-            .into_iter()
-            .map(|revision| {
-                let key = make_revision_key(&revision.doc_id, revision.rev_id);
-                let value = Bytes::from(revision.write_to_bytes().unwrap());
-                KeyValue { key, value }
-            })
-            .collect::<Vec<KeyValue>>();
-        let _ = kv_store.batch_set(items).await?;
-        // use futures::stream::{self, StreamExt};
-        // let f = |revision: Revision, kv_store: Arc<dyn KVStore>| async move {
-        //     let key = make_revision_key(&revision.doc_id, revision.rev_id);
-        //     let bytes = revision.write_to_bytes().unwrap();
-        //     let _ = kv_store.set(&key, Bytes::from(bytes)).await.unwrap();
-        // };
-        //
-        // stream::iter(revisions)
-        //     .for_each_concurrent(None, |revision| f(revision, kv_store.clone()))
-        //     .await;
-        Ok(())
+        let items = revisions_to_key_value_items(revisions);
+        self.inner
+            .transaction(|mut t| Box::pin(async move { t.batch_set(items).await }))
+            .await
     }
 
     pub(crate) async fn get_doc_revisions(&self, doc_id: &str) -> Result<RepeatedRevision, ServerError> {
-        let items = self.inner.batch_get_start_with(doc_id).await?;
+        let doc_id = doc_id.to_owned();
+        let items = self
+            .inner
+            .transaction(|mut t| Box::pin(async move { t.batch_get_start_with(&doc_id).await }))
+            .await?;
         Ok(key_value_items_to_revisions(items))
     }
 
@@ -111,13 +107,21 @@ impl DocumentKVPersistence {
     ) -> Result<RepeatedRevision, ServerError> {
         let rev_ids = rev_ids.into();
         let items = match rev_ids {
-            None => self.inner.batch_get_start_with(doc_id).await?,
+            None => {
+                let doc_id = doc_id.to_owned();
+                self.inner
+                    .transaction(|mut t| Box::pin(async move { t.batch_get_start_with(&doc_id).await }))
+                    .await?
+            },
             Some(rev_ids) => {
                 let keys = rev_ids
                     .into_iter()
                     .map(|rev_id| make_revision_key(doc_id, rev_id))
                     .collect::<Vec<String>>();
-                self.inner.batch_get(keys).await?
+
+                self.inner
+                    .transaction(|mut t| Box::pin(async move { t.batch_get(keys).await }))
+                    .await?
             },
         };
 
@@ -131,21 +135,37 @@ impl DocumentKVPersistence {
     ) -> Result<(), ServerError> {
         match rev_ids.into() {
             None => {
-                let _ = self.inner.batch_delete_key_start_with(doc_id).await?;
-                Ok(())
+                let doc_id = doc_id.to_owned();
+                self.inner
+                    .transaction(|mut t| Box::pin(async move { t.batch_delete_key_start_with(&doc_id).await }))
+                    .await
             },
             Some(rev_ids) => {
                 let keys = rev_ids
                     .into_iter()
                     .map(|rev_id| make_revision_key(doc_id, rev_id))
                     .collect::<Vec<String>>();
-                let _ = self.inner.batch_delete(keys).await?;
-                Ok(())
+
+                self.inner
+                    .transaction(|mut t| Box::pin(async move { t.batch_delete(keys).await }))
+                    .await
             },
         }
     }
 }
 
+#[inline]
+fn revisions_to_key_value_items(revisions: Vec<Revision>) -> Vec<KeyValue> {
+    revisions
+        .into_iter()
+        .map(|revision| {
+            let key = make_revision_key(&revision.doc_id, revision.rev_id);
+            let value = Bytes::from(revision.write_to_bytes().unwrap());
+            KeyValue { key, value }
+        })
+        .collect::<Vec<KeyValue>>()
+}
+
 #[inline]
 fn key_value_items_to_revisions(items: Vec<KeyValue>) -> RepeatedRevision {
     let mut revisions = items

+ 100 - 131
backend/src/services/kv/kv.rs

@@ -1,5 +1,5 @@
 use crate::{
-    services::kv::{KVAction, KVStore, KeyValue},
+    services::kv::{KVStore, KVTransaction, KeyValue},
     util::sqlx_ext::{map_sqlx_error, DBTransaction, SqlBuilder},
 };
 use anyhow::Context;
@@ -17,27 +17,26 @@ use sqlx::{
     Postgres,
     Row,
 };
-use std::{future::Future, pin::Pin};
+use std::{future::Future, pin::Pin, sync::Arc};
 
 const KV_TABLE: &str = "kv_table";
 
-pub(crate) struct PostgresKV {
+pub struct PostgresKV {
     pub(crate) pg_pool: PgPool,
 }
 
 impl PostgresKV {
-    async fn transaction<F, O>(&self, f: F) -> Result<O, ServerError>
+    pub async fn transaction<F, O>(&self, f: F) -> Result<O, ServerError>
     where
-        F: for<'a> FnOnce(&'a mut DBTransaction<'_>) -> BoxFuture<'a, Result<O, ServerError>>,
+        F: for<'a> FnOnce(Box<dyn KVTransaction + 'a>) -> BoxResultFuture<O, ServerError>,
     {
         let mut transaction = self
             .pg_pool
             .begin()
             .await
             .context("[KV]:Failed to acquire a Postgres connection")?;
-
-        let result = f(&mut transaction).await;
-
+        let postgres_transaction = PostgresTransaction(&mut transaction);
+        let result = f(Box::new(postgres_transaction)).await;
         transaction
             .commit()
             .await
@@ -47,43 +46,32 @@ impl PostgresKV {
     }
 }
 
-impl KVStore for PostgresKV {}
-
-pub(crate) struct PostgresTransaction<'a> {
-    pub(crate) transaction: DBTransaction<'a>,
-}
-
-impl<'a> PostgresTransaction<'a> {}
+pub(crate) struct PostgresTransaction<'a, 'b>(&'a mut DBTransaction<'b>);
 
 #[async_trait]
-impl KVAction for PostgresKV {
-    async fn get(&self, key: &str) -> Result<Option<Bytes>, ServerError> {
+impl<'a, 'b> KVTransaction for PostgresTransaction<'a, 'b> {
+    async fn get(&mut self, key: &str) -> Result<Option<Bytes>, ServerError> {
         let id = key.to_string();
-        self.transaction(|transaction| {
-            Box::pin(async move {
-                let (sql, args) = SqlBuilder::select(KV_TABLE)
-                    .add_field("*")
-                    .and_where_eq("id", &id)
-                    .build()?;
-
-                let result = sqlx::query_as_with::<Postgres, KVTable, PgArguments>(&sql, args)
-                    .fetch_one(transaction)
-                    .await;
-
-                let result = match result {
-                    Ok(val) => Ok(Some(Bytes::from(val.blob))),
-                    Err(error) => match error {
-                        Error::RowNotFound => Ok(None),
-                        _ => Err(map_sqlx_error(error)),
-                    },
-                };
-                result
-            })
-        })
-        .await
+        let (sql, args) = SqlBuilder::select(KV_TABLE)
+            .add_field("*")
+            .and_where_eq("id", &id)
+            .build()?;
+
+        let result = sqlx::query_as_with::<Postgres, KVTable, PgArguments>(&sql, args)
+            .fetch_one(self.0 as &mut DBTransaction<'b>)
+            .await;
+
+        let result = match result {
+            Ok(val) => Ok(Some(Bytes::from(val.blob))),
+            Err(error) => match error {
+                Error::RowNotFound => Ok(None),
+                _ => Err(map_sqlx_error(error)),
+            },
+        };
+        result
     }
 
-    async fn set(&self, key: &str, bytes: Bytes) -> Result<(), ServerError> {
+    async fn set(&mut self, key: &str, bytes: Bytes) -> Result<(), ServerError> {
         self.batch_set(vec![KeyValue {
             key: key.to_string(),
             value: bytes,
@@ -91,115 +79,96 @@ impl KVAction for PostgresKV {
         .await
     }
 
-    async fn remove(&self, key: &str) -> Result<(), ServerError> {
+    async fn remove(&mut self, key: &str) -> Result<(), ServerError> {
         let id = key.to_string();
-        self.transaction(|transaction| {
-            Box::pin(async move {
-                let (sql, args) = SqlBuilder::delete(KV_TABLE).and_where_eq("id", &id).build()?;
-                let _ = sqlx::query_with(&sql, args)
-                    .execute(transaction)
-                    .await
-                    .map_err(map_sqlx_error)?;
-                Ok(())
-            })
-        })
-        .await
+        let (sql, args) = SqlBuilder::delete(KV_TABLE).and_where_eq("id", &id).build()?;
+        let _ = sqlx::query_with(&sql, args)
+            .execute(self.0 as &mut DBTransaction<'_>)
+            .await
+            .map_err(map_sqlx_error)?;
+        Ok(())
     }
 
-    async fn batch_set(&self, kvs: Vec<KeyValue>) -> Result<(), ServerError> {
-        self.transaction(|transaction| {
-            Box::pin(async move {
-                let mut builder = RawSqlBuilder::insert_into(KV_TABLE);
-                let m_builder = builder.field("id").field("blob");
-
-                let mut args = PgArguments::default();
-                kvs.iter().enumerate().for_each(|(index, _)| {
-                    let index = index * 2 + 1;
-                    m_builder.values(&[format!("${}", index), format!("${}", index + 1)]);
-                });
-
-                for kv in kvs {
-                    args.add(kv.key);
-                    args.add(kv.value.to_vec());
-                }
-
-                let sql = m_builder.sql()?;
-                let _ = sqlx::query_with(&sql, args)
-                    .execute(transaction)
-                    .await
-                    .map_err(map_sqlx_error)?;
-
-                Ok::<(), ServerError>(())
-            })
-        })
-        .await
+    async fn batch_set(&mut self, kvs: Vec<KeyValue>) -> Result<(), ServerError> {
+        let mut builder = RawSqlBuilder::insert_into(KV_TABLE);
+        let m_builder = builder.field("id").field("blob");
+
+        let mut args = PgArguments::default();
+        kvs.iter().enumerate().for_each(|(index, _)| {
+            let index = index * 2 + 1;
+            m_builder.values(&[format!("${}", index), format!("${}", index + 1)]);
+        });
+
+        for kv in kvs {
+            args.add(kv.key);
+            args.add(kv.value.to_vec());
+        }
+
+        let sql = m_builder.sql()?;
+        let _ = sqlx::query_with(&sql, args)
+            .execute(self.0 as &mut DBTransaction<'_>)
+            .await
+            .map_err(map_sqlx_error)?;
+
+        Ok::<(), ServerError>(())
     }
 
-    async fn batch_get(&self, keys: Vec<String>) -> Result<Vec<KeyValue>, ServerError> {
-        self.transaction(|transaction| {
-            Box::pin(async move {
-                let sql = RawSqlBuilder::select_from(KV_TABLE)
-                    .field("id")
-                    .field("blob")
-                    .and_where_in_quoted("id", &keys)
-                    .sql()?;
-
-                let rows = sqlx::query(&sql).fetch_all(transaction).await.map_err(map_sqlx_error)?;
-                let kvs = rows_to_key_values(rows);
-                Ok::<Vec<KeyValue>, ServerError>(kvs)
-            })
-        })
-        .await
+    async fn batch_get(&mut self, keys: Vec<String>) -> Result<Vec<KeyValue>, ServerError> {
+        let sql = RawSqlBuilder::select_from(KV_TABLE)
+            .field("id")
+            .field("blob")
+            .and_where_in_quoted("id", &keys)
+            .sql()?;
+
+        let rows = sqlx::query(&sql)
+            .fetch_all(self.0 as &mut DBTransaction<'_>)
+            .await
+            .map_err(map_sqlx_error)?;
+        let kvs = rows_to_key_values(rows);
+        Ok::<Vec<KeyValue>, ServerError>(kvs)
     }
 
-    async fn batch_delete(&self, keys: Vec<String>) -> Result<(), ServerError> {
-        self.transaction(|transaction| {
-            Box::pin(async move {
-                let sql = RawSqlBuilder::delete_from(KV_TABLE).and_where_in("id", &keys).sql()?;
-                let _ = sqlx::query(&sql).execute(transaction).await.map_err(map_sqlx_error)?;
+    async fn batch_delete(&mut self, keys: Vec<String>) -> Result<(), ServerError> {
+        let sql = RawSqlBuilder::delete_from(KV_TABLE).and_where_in("id", &keys).sql()?;
+        let _ = sqlx::query(&sql)
+            .execute(self.0 as &mut DBTransaction<'_>)
+            .await
+            .map_err(map_sqlx_error)?;
 
-                Ok::<(), ServerError>(())
-            })
-        })
-        .await
+        Ok::<(), ServerError>(())
     }
 
-    async fn batch_get_start_with(&self, key: &str) -> Result<Vec<KeyValue>, ServerError> {
+    async fn batch_get_start_with(&mut self, key: &str) -> Result<Vec<KeyValue>, ServerError> {
         let prefix = key.to_owned();
-        self.transaction(|transaction| {
-            Box::pin(async move {
-                let sql = RawSqlBuilder::select_from(KV_TABLE)
-                    .field("id")
-                    .field("blob")
-                    .and_where_like_left("id", &prefix)
-                    .sql()?;
-
-                let rows = sqlx::query(&sql).fetch_all(transaction).await.map_err(map_sqlx_error)?;
+        let sql = RawSqlBuilder::select_from(KV_TABLE)
+            .field("id")
+            .field("blob")
+            .and_where_like_left("id", &prefix)
+            .sql()?;
+
+        let rows = sqlx::query(&sql)
+            .fetch_all(self.0 as &mut DBTransaction<'_>)
+            .await
+            .map_err(map_sqlx_error)?;
 
-                let kvs = rows_to_key_values(rows);
+        let kvs = rows_to_key_values(rows);
 
-                Ok::<Vec<KeyValue>, ServerError>(kvs)
-            })
-        })
-        .await
+        Ok::<Vec<KeyValue>, ServerError>(kvs)
     }
 
-    async fn batch_delete_key_start_with(&self, keyword: &str) -> Result<(), ServerError> {
+    async fn batch_delete_key_start_with(&mut self, keyword: &str) -> Result<(), ServerError> {
         let keyword = keyword.to_owned();
-        self.transaction(|transaction| {
-            Box::pin(async move {
-                let sql = RawSqlBuilder::delete_from(KV_TABLE)
-                    .and_where_like_left("id", &keyword)
-                    .sql()?;
-
-                let _ = sqlx::query(&sql).execute(transaction).await.map_err(map_sqlx_error)?;
-                Ok::<(), ServerError>(())
-            })
-        })
-        .await
+        let sql = RawSqlBuilder::delete_from(KV_TABLE)
+            .and_where_like_left("id", &keyword)
+            .sql()?;
+
+        let _ = sqlx::query(&sql)
+            .execute(self.0 as &mut DBTransaction<'_>)
+            .await
+            .map_err(map_sqlx_error)?;
+        Ok::<(), ServerError>(())
     }
 }
-
 fn rows_to_key_values(rows: Vec<PgRow>) -> Vec<KeyValue> {
     rows.into_iter()
         .map(|row| {

+ 19 - 18
backend/src/services/kv/mod.rs

@@ -5,16 +5,16 @@ use async_trait::async_trait;
 use bytes::Bytes;
 use futures_core::future::BoxFuture;
 pub(crate) use kv::*;
+use std::sync::Arc;
 
 use backend_service::errors::ServerError;
 use lib_infra::future::{BoxResultFuture, FutureResultSend};
 
-#[derive(Clone, Debug, PartialEq, Eq)]
-pub struct KeyValue {
-    pub key: String,
-    pub value: Bytes,
-}
+// TODO: Generic the KVStore that enable switching KVStore to another
+// implementation
+pub type KVStore = PostgresKV;
 
+#[rustfmt::skip]
 // https://rust-lang.github.io/async-book/07_workarounds/05_async_in_traits.html
 // Note that using these trait methods will result in a heap allocation
 // per-function-call. This is not a significant cost for the vast majority of
@@ -22,20 +22,21 @@ pub struct KeyValue {
 // functionality in the public API of a low-level function that is expected to
 // be called millions of times a second.
 #[async_trait]
-pub trait KVStore: KVAction + Send + Sync {}
-
-// pub trait KVTransaction
+pub trait KVTransaction: Send + Sync {
+    async fn get(&mut self, key: &str) -> Result<Option<Bytes>, ServerError>;
+    async fn set(&mut self, key: &str, value: Bytes) -> Result<(), ServerError>;
+    async fn remove(&mut self, key: &str) -> Result<(), ServerError>;
 
-#[async_trait]
-pub trait KVAction: Send + Sync {
-    async fn get(&self, key: &str) -> Result<Option<Bytes>, ServerError>;
-    async fn set(&self, key: &str, value: Bytes) -> Result<(), ServerError>;
-    async fn remove(&self, key: &str) -> Result<(), ServerError>;
+    async fn batch_set(&mut self, kvs: Vec<KeyValue>) -> Result<(), ServerError>;
+    async fn batch_get(&mut self, keys: Vec<String>) -> Result<Vec<KeyValue>, ServerError>;
+    async fn batch_delete(&mut self, keys: Vec<String>) -> Result<(), ServerError>;
 
-    async fn batch_set(&self, kvs: Vec<KeyValue>) -> Result<(), ServerError>;
-    async fn batch_get(&self, keys: Vec<String>) -> Result<Vec<KeyValue>, ServerError>;
-    async fn batch_delete(&self, keys: Vec<String>) -> Result<(), ServerError>;
+    async fn batch_get_start_with(&mut self, key: &str) -> Result<Vec<KeyValue>, ServerError>;
+    async fn batch_delete_key_start_with(&mut self, keyword: &str) -> Result<(), ServerError>;
+}
 
-    async fn batch_get_start_with(&self, key: &str) -> Result<Vec<KeyValue>, ServerError>;
-    async fn batch_delete_key_start_with(&self, keyword: &str) -> Result<(), ServerError>;
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct KeyValue {
+    pub key: String,
+    pub value: Bytes,
 }

+ 1 - 0
backend/tests/api_test/kv_test.rs

@@ -41,6 +41,7 @@ async fn kv_batch_set_test() {
             value: "b".to_string().into(),
         },
     ];
+
     kv.batch_set(kvs.clone()).await.unwrap();
     let kvs_from_db = kv
         .batch_get(kvs.clone().into_iter().map(|value| value.key).collect::<Vec<String>>())