user_session.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. use crate::entities::{
  2. SignInParams, SignInResponse, SignUpParams, SignUpResponse, UpdateUserProfileParams, UserProfilePB, UserSettingPB,
  3. };
  4. use crate::{
  5. errors::{ErrorCode, FlowyError},
  6. event_map::UserCloudService,
  7. notification::*,
  8. services::{
  9. database::{UserDB, UserTable, UserTableChangeset},
  10. notifier::UserNotifier,
  11. },
  12. };
  13. use flowy_database::ConnectionPool;
  14. use flowy_database::{
  15. kv::KV,
  16. query_dsl::*,
  17. schema::{user_table, user_table::dsl},
  18. DBConnection, ExpressionMethods, UserDatabaseConnection,
  19. };
  20. use serde::{Deserialize, Serialize};
  21. use std::sync::Arc;
  22. use tokio::sync::mpsc;
  23. pub struct UserSessionConfig {
  24. root_dir: String,
  25. /// Used as the key of `Session` when saving session information to KV.
  26. session_cache_key: String,
  27. }
  28. impl UserSessionConfig {
  29. /// The `root_dir` represents as the root of the user folders. It must be unique for each
  30. /// users.
  31. pub fn new(name: &str, root_dir: &str) -> Self {
  32. let session_cache_key = format!("{}_session_cache", name);
  33. Self {
  34. root_dir: root_dir.to_owned(),
  35. session_cache_key,
  36. }
  37. }
  38. }
  39. pub struct UserSession {
  40. database: UserDB,
  41. config: UserSessionConfig,
  42. cloud_service: Arc<dyn UserCloudService>,
  43. pub notifier: UserNotifier,
  44. }
  45. impl UserSession {
  46. pub fn new(config: UserSessionConfig, cloud_service: Arc<dyn UserCloudService>) -> Self {
  47. let db = UserDB::new(&config.root_dir);
  48. let notifier = UserNotifier::new();
  49. Self {
  50. database: db,
  51. config,
  52. cloud_service,
  53. notifier,
  54. }
  55. }
  56. pub fn init(&self) {
  57. if let Ok(session) = self.get_session() {
  58. self.notifier.notify_login(&session.token, &session.user_id);
  59. }
  60. }
  61. pub fn db_connection(&self) -> Result<DBConnection, FlowyError> {
  62. let user_id = self.get_session()?.user_id;
  63. self.database.get_connection(&user_id)
  64. }
  65. // The caller will be not 'Sync' before of the return value,
  66. // PooledConnection<ConnectionManager> is not sync. You can use
  67. // db_connection_pool function to require the ConnectionPool that is 'Sync'.
  68. //
  69. // let pool = self.db_connection_pool()?;
  70. // let conn: PooledConnection<ConnectionManager> = pool.get()?;
  71. pub fn db_pool(&self) -> Result<Arc<ConnectionPool>, FlowyError> {
  72. let user_id = self.get_session()?.user_id;
  73. self.database.get_pool(&user_id)
  74. }
  75. #[tracing::instrument(level = "debug", skip(self))]
  76. pub async fn sign_in(&self, params: SignInParams) -> Result<UserProfilePB, FlowyError> {
  77. if self.is_user_login(&params.email) {
  78. self.get_user_profile().await
  79. } else {
  80. let resp = self.cloud_service.sign_in(params).await?;
  81. let session: Session = resp.clone().into();
  82. self.set_session(Some(session))?;
  83. let user_table = self.save_user(resp.into()).await?;
  84. let user_profile: UserProfilePB = user_table.into();
  85. self.notifier.notify_login(&user_profile.token, &user_profile.id);
  86. Ok(user_profile)
  87. }
  88. }
  89. #[tracing::instrument(level = "debug", skip(self))]
  90. pub async fn sign_up(&self, params: SignUpParams) -> Result<UserProfilePB, FlowyError> {
  91. if self.is_user_login(&params.email) {
  92. self.get_user_profile().await
  93. } else {
  94. let resp = self.cloud_service.sign_up(params).await?;
  95. let session: Session = resp.clone().into();
  96. self.set_session(Some(session))?;
  97. let user_table = self.save_user(resp.into()).await?;
  98. let user_profile: UserProfilePB = user_table.into();
  99. let (ret, mut tx) = mpsc::channel(1);
  100. self.notifier.notify_sign_up(ret, &user_profile);
  101. let _ = tx.recv().await;
  102. Ok(user_profile)
  103. }
  104. }
  105. #[tracing::instrument(level = "debug", skip(self))]
  106. pub async fn sign_out(&self) -> Result<(), FlowyError> {
  107. let session = self.get_session()?;
  108. let _ =
  109. diesel::delete(dsl::user_table.filter(dsl::id.eq(&session.user_id))).execute(&*(self.db_connection()?))?;
  110. self.database.close_user_db(&session.user_id)?;
  111. self.set_session(None)?;
  112. self.notifier.notify_logout(&session.token, &session.user_id);
  113. self.sign_out_on_server(&session.token).await?;
  114. Ok(())
  115. }
  116. #[tracing::instrument(level = "debug", skip(self))]
  117. pub async fn update_user_profile(&self, params: UpdateUserProfileParams) -> Result<(), FlowyError> {
  118. let session = self.get_session()?;
  119. let changeset = UserTableChangeset::new(params.clone());
  120. diesel_update_table!(user_table, changeset, &*self.db_connection()?);
  121. let user_profile = self.get_user_profile().await?;
  122. dart_notify(&session.token, UserNotification::UserProfileUpdated)
  123. .payload(user_profile)
  124. .send();
  125. self.update_user_on_server(&session.token, params).await?;
  126. Ok(())
  127. }
  128. pub async fn init_user(&self) -> Result<(), FlowyError> {
  129. Ok(())
  130. }
  131. pub async fn check_user(&self) -> Result<UserProfilePB, FlowyError> {
  132. let (user_id, token) = self.get_session()?.into_part();
  133. let user = dsl::user_table
  134. .filter(user_table::id.eq(&user_id))
  135. .first::<UserTable>(&*(self.db_connection()?))?;
  136. self.read_user_profile_on_server(&token)?;
  137. Ok(user.into())
  138. }
  139. pub async fn get_user_profile(&self) -> Result<UserProfilePB, FlowyError> {
  140. let (user_id, token) = self.get_session()?.into_part();
  141. let user = dsl::user_table
  142. .filter(user_table::id.eq(&user_id))
  143. .first::<UserTable>(&*(self.db_connection()?))?;
  144. self.read_user_profile_on_server(&token)?;
  145. Ok(user.into())
  146. }
  147. pub fn user_dir(&self) -> Result<String, FlowyError> {
  148. let session = self.get_session()?;
  149. Ok(format!("{}/{}", self.config.root_dir, session.user_id))
  150. }
  151. pub fn user_setting(&self) -> Result<UserSettingPB, FlowyError> {
  152. let user_setting = UserSettingPB {
  153. user_folder: self.user_dir()?,
  154. };
  155. Ok(user_setting)
  156. }
  157. pub fn user_id(&self) -> Result<String, FlowyError> {
  158. Ok(self.get_session()?.user_id)
  159. }
  160. pub fn user_name(&self) -> Result<String, FlowyError> {
  161. Ok(self.get_session()?.name)
  162. }
  163. pub fn token(&self) -> Result<String, FlowyError> {
  164. Ok(self.get_session()?.token)
  165. }
  166. }
  167. impl UserSession {
  168. fn read_user_profile_on_server(&self, _token: &str) -> Result<(), FlowyError> {
  169. // let server = self.cloud_service.clone();
  170. // let token = token.to_owned();
  171. // tokio::spawn(async move {
  172. // match server.get_user(&token).await {
  173. // Ok(profile) => {
  174. // dart_notify(&token, UserNotification::UserProfileUpdated)
  175. // .payload(profile)
  176. // .send();
  177. // }
  178. // Err(e) => {
  179. // dart_notify(&token, UserNotification::UserProfileUpdated)
  180. // .error(e)
  181. // .send();
  182. // }
  183. // }
  184. // });
  185. Ok(())
  186. }
  187. async fn update_user_on_server(&self, token: &str, params: UpdateUserProfileParams) -> Result<(), FlowyError> {
  188. let server = self.cloud_service.clone();
  189. let token = token.to_owned();
  190. let _ = tokio::spawn(async move {
  191. match server.update_user(&token, params).await {
  192. Ok(_) => {}
  193. Err(e) => {
  194. // TODO: retry?
  195. log::error!("update user profile failed: {:?}", e);
  196. }
  197. }
  198. })
  199. .await;
  200. Ok(())
  201. }
  202. async fn sign_out_on_server(&self, token: &str) -> Result<(), FlowyError> {
  203. let server = self.cloud_service.clone();
  204. let token = token.to_owned();
  205. let _ = tokio::spawn(async move {
  206. match server.sign_out(&token).await {
  207. Ok(_) => {}
  208. Err(e) => log::error!("Sign out failed: {:?}", e),
  209. }
  210. })
  211. .await;
  212. Ok(())
  213. }
  214. async fn save_user(&self, user: UserTable) -> Result<UserTable, FlowyError> {
  215. let conn = self.db_connection()?;
  216. let _ = diesel::insert_into(user_table::table)
  217. .values(user.clone())
  218. .execute(&*conn)?;
  219. Ok(user)
  220. }
  221. fn set_session(&self, session: Option<Session>) -> Result<(), FlowyError> {
  222. tracing::debug!("Set user session: {:?}", session);
  223. match &session {
  224. None => KV::remove(&self.config.session_cache_key).map_err(|e| FlowyError::new(ErrorCode::Internal, &e))?,
  225. Some(session) => KV::set_str(&self.config.session_cache_key, session.clone().into()),
  226. }
  227. Ok(())
  228. }
  229. fn get_session(&self) -> Result<Session, FlowyError> {
  230. match KV::get_str(&self.config.session_cache_key) {
  231. None => Err(FlowyError::unauthorized()),
  232. Some(s) => Ok(Session::from(s)),
  233. }
  234. }
  235. fn is_user_login(&self, email: &str) -> bool {
  236. match self.get_session() {
  237. Ok(session) => session.email == email,
  238. Err(_) => false,
  239. }
  240. }
  241. }
  242. pub async fn update_user(
  243. _cloud_service: Arc<dyn UserCloudService>,
  244. pool: Arc<ConnectionPool>,
  245. params: UpdateUserProfileParams,
  246. ) -> Result<(), FlowyError> {
  247. let changeset = UserTableChangeset::new(params);
  248. let conn = pool.get()?;
  249. diesel_update_table!(user_table, changeset, &*conn);
  250. Ok(())
  251. }
  252. impl UserDatabaseConnection for UserSession {
  253. fn get_connection(&self) -> Result<DBConnection, String> {
  254. self.db_connection().map_err(|e| format!("{:?}", e))
  255. }
  256. }
  257. #[derive(Debug, Clone, Default, Serialize, Deserialize)]
  258. struct Session {
  259. user_id: String,
  260. token: String,
  261. email: String,
  262. #[serde(default)]
  263. name: String,
  264. }
  265. impl std::convert::From<SignInResponse> for Session {
  266. fn from(resp: SignInResponse) -> Self {
  267. Session {
  268. user_id: resp.user_id,
  269. token: resp.token,
  270. email: resp.email,
  271. name: resp.name,
  272. }
  273. }
  274. }
  275. impl std::convert::From<SignUpResponse> for Session {
  276. fn from(resp: SignUpResponse) -> Self {
  277. Session {
  278. user_id: resp.user_id,
  279. token: resp.token,
  280. email: resp.email,
  281. name: resp.name,
  282. }
  283. }
  284. }
  285. impl Session {
  286. pub fn into_part(self) -> (String, String) {
  287. (self.user_id, self.token)
  288. }
  289. }
  290. impl std::convert::From<String> for Session {
  291. fn from(s: String) -> Self {
  292. match serde_json::from_str(&s) {
  293. Ok(s) => s,
  294. Err(e) => {
  295. log::error!("Deserialize string to Session failed: {:?}", e);
  296. Session::default()
  297. }
  298. }
  299. }
  300. }
  301. impl std::convert::From<Session> for String {
  302. fn from(session: Session) -> Self {
  303. match serde_json::to_string(&session) {
  304. Ok(s) => s,
  305. Err(e) => {
  306. log::error!("Serialize session to string failed: {:?}", e);
  307. "".to_string()
  308. }
  309. }
  310. }
  311. }