user_session.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. use std::sync::Arc;
  2. use appflowy_integrate::RocksCollabDB;
  3. use serde::{Deserialize, Serialize};
  4. use serde_repr::*;
  5. use tokio::sync::RwLock;
  6. use flowy_error::internal_error;
  7. use flowy_sqlite::ConnectionPool;
  8. use flowy_sqlite::{
  9. kv::KV,
  10. query_dsl::*,
  11. schema::{user_table, user_table::dsl},
  12. DBConnection, ExpressionMethods, UserDatabaseConnection,
  13. };
  14. use lib_infra::box_any::BoxAny;
  15. use crate::entities::{
  16. AuthTypePB, SignInResponse, SignUpResponse, UpdateUserProfileParams, UserProfile,
  17. };
  18. use crate::entities::{UserProfilePB, UserSettingPB};
  19. use crate::event_map::{DefaultUserStatusCallback, UserCloudServiceProvider, UserStatusCallback};
  20. use crate::{
  21. errors::FlowyError,
  22. event_map::UserAuthService,
  23. notification::*,
  24. services::database::{UserDB, UserTable, UserTableChangeset},
  25. };
  /// Static configuration consumed by `UserSession`.
  pub struct UserSessionConfig {
    /// Root directory under which the user folders are created.
    root_dir: String,
    /// Key under which the `Session` is stored when it is saved to the KV store.
    session_cache_key: String,
  }

  impl UserSessionConfig {
    /// Builds a configuration.
    ///
    /// * `name` - prefix of the KV key; the session is cached under
    ///   `"{name}_session_cache"`.
    /// * `root_dir` - root of the user folders. It must be unique for each user.
    pub fn new(name: &str, root_dir: &str) -> Self {
      Self {
        session_cache_key: format!("{}_session_cache", name),
        root_dir: root_dir.to_string(),
      }
    }
  }
  42. pub struct UserSession {
  43. database: UserDB,
  44. session_config: UserSessionConfig,
  45. cloud_services: Arc<dyn UserCloudServiceProvider>,
  46. user_status_callback: RwLock<Arc<dyn UserStatusCallback>>,
  47. }
  48. impl UserSession {
  49. pub fn new(
  50. session_config: UserSessionConfig,
  51. cloud_services: Arc<dyn UserCloudServiceProvider>,
  52. ) -> Self {
  53. let db = UserDB::new(&session_config.root_dir);
  54. let user_status_callback: RwLock<Arc<dyn UserStatusCallback>> =
  55. RwLock::new(Arc::new(DefaultUserStatusCallback));
  56. Self {
  57. database: db,
  58. session_config,
  59. cloud_services,
  60. user_status_callback,
  61. }
  62. }
  63. pub async fn init<C: UserStatusCallback + 'static>(&self, user_status_callback: C) {
  64. if let Ok(session) = self.get_session() {
  65. let _ = user_status_callback
  66. .did_sign_in(session.user_id, &session.workspace_id)
  67. .await;
  68. }
  69. *self.user_status_callback.write().await = Arc::new(user_status_callback);
  70. }
  71. pub fn db_connection(&self) -> Result<DBConnection, FlowyError> {
  72. let user_id = self.get_session()?.user_id;
  73. self.database.get_connection(user_id)
  74. }
  75. // The caller will be not 'Sync' before of the return value,
  76. // PooledConnection<ConnectionManager> is not sync. You can use
  77. // db_connection_pool function to require the ConnectionPool that is 'Sync'.
  78. //
  79. // let pool = self.db_connection_pool()?;
  80. // let conn: PooledConnection<ConnectionManager> = pool.get()?;
  81. pub fn db_pool(&self) -> Result<Arc<ConnectionPool>, FlowyError> {
  82. let user_id = self.get_session()?.user_id;
  83. self.database.get_pool(user_id)
  84. }
  85. pub fn get_collab_db(&self) -> Result<Arc<RocksCollabDB>, FlowyError> {
  86. let user_id = self.get_session()?.user_id;
  87. self.database.get_kv_db(user_id)
  88. }
  89. #[tracing::instrument(level = "debug", skip(self, params))]
  90. pub async fn sign_in(
  91. &self,
  92. auth_type: &AuthType,
  93. params: BoxAny,
  94. ) -> Result<UserProfile, FlowyError> {
  95. self
  96. .user_status_callback
  97. .read()
  98. .await
  99. .auth_type_did_changed(auth_type.clone());
  100. self.cloud_services.set_auth_type(auth_type.clone());
  101. let resp = self
  102. .cloud_services
  103. .get_auth_service()?
  104. .sign_in(params)
  105. .await?;
  106. let session: Session = resp.clone().into();
  107. self.set_session(Some(session))?;
  108. let user_profile: UserProfile = self.save_user(resp.into()).await?.into();
  109. let _ = self
  110. .user_status_callback
  111. .read()
  112. .await
  113. .did_sign_in(user_profile.id, &user_profile.workspace_id)
  114. .await;
  115. send_sign_in_notification()
  116. .payload::<UserProfilePB>(user_profile.clone().into())
  117. .send();
  118. Ok(user_profile)
  119. }
  120. #[tracing::instrument(level = "debug", skip(self, params))]
  121. pub async fn sign_up(
  122. &self,
  123. auth_type: &AuthType,
  124. params: BoxAny,
  125. ) -> Result<UserProfile, FlowyError> {
  126. self
  127. .user_status_callback
  128. .read()
  129. .await
  130. .auth_type_did_changed(auth_type.clone());
  131. self.cloud_services.set_auth_type(auth_type.clone());
  132. let resp = self
  133. .cloud_services
  134. .get_auth_service()?
  135. .sign_up(params)
  136. .await?;
  137. let session: Session = resp.clone().into();
  138. self.set_session(Some(session))?;
  139. let user_table = self.save_user(resp.into()).await?;
  140. let user_profile: UserProfile = user_table.into();
  141. let _ = self
  142. .user_status_callback
  143. .read()
  144. .await
  145. .did_sign_up(&user_profile)
  146. .await;
  147. Ok(user_profile)
  148. }
  149. #[tracing::instrument(level = "debug", skip(self))]
  150. pub async fn sign_out(&self, auth_type: &AuthType) -> Result<(), FlowyError> {
  151. let session = self.get_session()?;
  152. let uid = session.user_id.to_string();
  153. let _ = diesel::delete(dsl::user_table.filter(dsl::id.eq(&uid)))
  154. .execute(&*(self.db_connection()?))?;
  155. self.database.close_user_db(session.user_id)?;
  156. self.set_session(None)?;
  157. let server = self.cloud_services.get_auth_service()?;
  158. let token = session.token;
  159. let _ = tokio::spawn(async move {
  160. match server.sign_out(token).await {
  161. Ok(_) => {},
  162. Err(e) => tracing::error!("Sign out failed: {:?}", e),
  163. }
  164. });
  165. Ok(())
  166. }
  167. #[tracing::instrument(level = "debug", skip(self))]
  168. pub async fn update_user_profile(
  169. &self,
  170. params: UpdateUserProfileParams,
  171. ) -> Result<(), FlowyError> {
  172. let auth_type = params.auth_type.clone();
  173. let session = self.get_session()?;
  174. let changeset = UserTableChangeset::new(params.clone());
  175. diesel_update_table!(user_table, changeset, &*self.db_connection()?);
  176. let user_profile = self.get_user_profile().await?;
  177. let profile_pb: UserProfilePB = user_profile.into();
  178. send_notification(
  179. &session.user_id.to_string(),
  180. UserNotification::DidUpdateUserProfile,
  181. )
  182. .payload(profile_pb)
  183. .send();
  184. self
  185. .update_user(&auth_type, session.user_id, &session.token, params)
  186. .await?;
  187. Ok(())
  188. }
  189. pub async fn init_user(&self) -> Result<(), FlowyError> {
  190. Ok(())
  191. }
  192. pub async fn check_user(&self) -> Result<UserProfile, FlowyError> {
  193. let (user_id, _token) = self.get_session()?.into_part();
  194. let user_id = user_id.to_string();
  195. let user = dsl::user_table
  196. .filter(user_table::id.eq(&user_id))
  197. .first::<UserTable>(&*(self.db_connection()?))?;
  198. Ok(user.into())
  199. }
  200. pub async fn get_user_profile(&self) -> Result<UserProfile, FlowyError> {
  201. let (user_id, _) = self.get_session()?.into_part();
  202. let user_id = user_id.to_string();
  203. let user = dsl::user_table
  204. .filter(user_table::id.eq(&user_id))
  205. .first::<UserTable>(&*(self.db_connection()?))?;
  206. Ok(user.into())
  207. }
  208. pub fn user_dir(&self) -> Result<String, FlowyError> {
  209. let session = self.get_session()?;
  210. Ok(format!(
  211. "{}/{}",
  212. self.session_config.root_dir, session.user_id
  213. ))
  214. }
  215. pub fn user_setting(&self) -> Result<UserSettingPB, FlowyError> {
  216. let user_setting = UserSettingPB {
  217. user_folder: self.user_dir()?,
  218. };
  219. Ok(user_setting)
  220. }
  221. pub fn user_id(&self) -> Result<i64, FlowyError> {
  222. Ok(self.get_session()?.user_id)
  223. }
  224. pub fn user_name(&self) -> Result<String, FlowyError> {
  225. Ok(self.get_session()?.name)
  226. }
  227. pub fn token(&self) -> Result<Option<String>, FlowyError> {
  228. Ok(self.get_session()?.token)
  229. }
  230. }
  231. impl UserSession {
  232. async fn update_user(
  233. &self,
  234. _auth_type: &AuthType,
  235. uid: i64,
  236. token: &Option<String>,
  237. params: UpdateUserProfileParams,
  238. ) -> Result<(), FlowyError> {
  239. let server = self.cloud_services.get_auth_service()?;
  240. let token = token.to_owned();
  241. let _ = tokio::spawn(async move {
  242. match server.update_user(uid, &token, params).await {
  243. Ok(_) => {},
  244. Err(e) => {
  245. // TODO: retry?
  246. tracing::error!("update user profile failed: {:?}", e);
  247. },
  248. }
  249. })
  250. .await;
  251. Ok(())
  252. }
  253. async fn save_user(&self, user: UserTable) -> Result<UserTable, FlowyError> {
  254. let conn = self.db_connection()?;
  255. let _ = diesel::insert_into(user_table::table)
  256. .values(user.clone())
  257. .execute(&*conn)?;
  258. Ok(user)
  259. }
  260. fn set_session(&self, session: Option<Session>) -> Result<(), FlowyError> {
  261. tracing::debug!("Set user session: {:?}", session);
  262. match &session {
  263. None => KV::remove(&self.session_config.session_cache_key),
  264. Some(session) => {
  265. KV::set_object(&self.session_config.session_cache_key, session.clone())
  266. .map_err(internal_error)?;
  267. },
  268. }
  269. Ok(())
  270. }
  271. fn get_session(&self) -> Result<Session, FlowyError> {
  272. match KV::get_object::<Session>(&self.session_config.session_cache_key) {
  273. None => Err(FlowyError::unauthorized()),
  274. Some(session) => Ok(session),
  275. }
  276. }
  277. }
  278. pub async fn update_user(
  279. _cloud_service: Arc<dyn UserAuthService>,
  280. pool: Arc<ConnectionPool>,
  281. params: UpdateUserProfileParams,
  282. ) -> Result<(), FlowyError> {
  283. let changeset = UserTableChangeset::new(params);
  284. let conn = pool.get()?;
  285. diesel_update_table!(user_table, changeset, &*conn);
  286. Ok(())
  287. }
  288. impl UserDatabaseConnection for UserSession {
  289. fn get_connection(&self) -> Result<DBConnection, String> {
  290. self.db_connection().map_err(|e| format!("{:?}", e))
  291. }
  292. }
  293. #[derive(Debug, Clone, Default, Serialize, Deserialize)]
  294. struct Session {
  295. user_id: i64,
  296. workspace_id: String,
  297. #[serde(default)]
  298. name: String,
  299. #[serde(default)]
  300. token: Option<String>,
  301. #[serde(default)]
  302. email: Option<String>,
  303. }
  304. impl std::convert::From<SignInResponse> for Session {
  305. fn from(resp: SignInResponse) -> Self {
  306. Session {
  307. user_id: resp.user_id,
  308. token: resp.token,
  309. email: resp.email,
  310. name: resp.name,
  311. workspace_id: resp.workspace_id,
  312. }
  313. }
  314. }
  315. impl std::convert::From<SignUpResponse> for Session {
  316. fn from(resp: SignUpResponse) -> Self {
  317. Session {
  318. user_id: resp.user_id,
  319. token: resp.token,
  320. email: resp.email,
  321. name: resp.name,
  322. workspace_id: resp.workspace_id,
  323. }
  324. }
  325. }
  326. impl Session {
  327. pub fn into_part(self) -> (i64, Option<String>) {
  328. (self.user_id, self.token)
  329. }
  330. }
  331. impl std::convert::From<String> for Session {
  332. fn from(s: String) -> Self {
  333. match serde_json::from_str(&s) {
  334. Ok(s) => s,
  335. Err(e) => {
  336. tracing::error!("Deserialize string to Session failed: {:?}", e);
  337. Session::default()
  338. },
  339. }
  340. }
  341. }
  342. impl std::convert::From<Session> for String {
  343. fn from(session: Session) -> Self {
  344. match serde_json::to_string(&session) {
  345. Ok(s) => s,
  346. Err(e) => {
  347. tracing::error!("Serialize session to string failed: {:?}", e);
  348. "".to_string()
  349. },
  350. }
  351. }
  352. }
  353. #[derive(Debug, Clone, Hash, Serialize_repr, Deserialize_repr, Eq, PartialEq)]
  354. #[repr(u8)]
  355. pub enum AuthType {
  356. /// It's a local server, we do fake sign in default.
  357. Local = 0,
  358. /// Currently not supported. It will be supported in the future when the
  359. /// [AppFlowy-Server](https://github.com/AppFlowy-IO/AppFlowy-Server) ready.
  360. SelfHosted = 1,
  361. /// It uses Supabase as the backend.
  362. Supabase = 2,
  363. }
  364. impl Default for AuthType {
  365. fn default() -> Self {
  366. Self::Local
  367. }
  368. }
  369. impl From<AuthTypePB> for AuthType {
  370. fn from(pb: AuthTypePB) -> Self {
  371. match pb {
  372. AuthTypePB::Supabase => AuthType::Supabase,
  373. AuthTypePB::Local => AuthType::Local,
  374. AuthTypePB::SelfHosted => AuthType::SelfHosted,
  375. }
  376. }
  377. }