// user_session.rs (11 KB)
  1. use crate::entities::{
  2. SignInParams, SignInResponse, SignUpParams, SignUpResponse, UpdateUserProfileParams, UserProfile,
  3. };
  4. use crate::{
  5. dart_notification::*,
  6. errors::{ErrorCode, FlowyError},
  7. event_map::UserCloudService,
  8. services::{
  9. database::{UserDB, UserTable, UserTableChangeset},
  10. notifier::UserNotifier,
  11. },
  12. };
  13. use flowy_database::ConnectionPool;
  14. use flowy_database::{
  15. kv::KV,
  16. query_dsl::*,
  17. schema::{user_table, user_table::dsl},
  18. DBConnection, ExpressionMethods, UserDatabaseConnection,
  19. };
  20. use parking_lot::RwLock;
  21. use serde::{Deserialize, Serialize};
  22. use std::sync::Arc;
  23. use tokio::sync::mpsc;
  /// Static settings for a user session: the directory under which per-user
  /// data is kept, and the key-value entry that caches the serialized session.
  pub struct UserSessionConfig {
      root_dir: String,
      session_cache_key: String,
  }

  impl UserSessionConfig {
      /// Builds a config, copying both borrowed strings into owned ones.
      pub fn new(root_dir: &str, session_cache_key: &str) -> Self {
          let owned_root = String::from(root_dir);
          let owned_key = String::from(session_cache_key);
          UserSessionConfig {
              root_dir: owned_root,
              session_cache_key: owned_key,
          }
      }
  }
  36. pub struct UserSession {
  37. database: UserDB,
  38. config: UserSessionConfig,
  39. cloud_service: Arc<dyn UserCloudService>,
  40. session: RwLock<Option<Session>>,
  41. pub notifier: UserNotifier,
  42. }
  43. impl UserSession {
  44. pub fn new(config: UserSessionConfig, cloud_service: Arc<dyn UserCloudService>) -> Self {
  45. let db = UserDB::new(&config.root_dir);
  46. let notifier = UserNotifier::new();
  47. Self {
  48. database: db,
  49. config,
  50. cloud_service,
  51. session: RwLock::new(None),
  52. notifier,
  53. }
  54. }
  55. pub fn init(&self) {
  56. if let Ok(session) = self.get_session() {
  57. self.notifier.notify_login(&session.token, &session.user_id);
  58. }
  59. }
  60. pub fn db_connection(&self) -> Result<DBConnection, FlowyError> {
  61. let user_id = self.get_session()?.user_id;
  62. self.database.get_connection(&user_id)
  63. }
  64. // The caller will be not 'Sync' before of the return value,
  65. // PooledConnection<ConnectionManager> is not sync. You can use
  66. // db_connection_pool function to require the ConnectionPool that is 'Sync'.
  67. //
  68. // let pool = self.db_connection_pool()?;
  69. // let conn: PooledConnection<ConnectionManager> = pool.get()?;
  70. pub fn db_pool(&self) -> Result<Arc<ConnectionPool>, FlowyError> {
  71. let user_id = self.get_session()?.user_id;
  72. self.database.get_pool(&user_id)
  73. }
  74. #[tracing::instrument(level = "debug", skip(self))]
  75. pub async fn sign_in(&self, params: SignInParams) -> Result<UserProfile, FlowyError> {
  76. if self.is_user_login(&params.email) {
  77. self.get_user_profile().await
  78. } else {
  79. let resp = self.cloud_service.sign_in(params).await?;
  80. let session: Session = resp.clone().into();
  81. let _ = self.set_session(Some(session))?;
  82. let user_table = self.save_user(resp.into()).await?;
  83. let user_profile: UserProfile = user_table.into();
  84. self.notifier.notify_login(&user_profile.token, &user_profile.id);
  85. Ok(user_profile)
  86. }
  87. }
  88. #[tracing::instrument(level = "debug", skip(self))]
  89. pub async fn sign_up(&self, params: SignUpParams) -> Result<UserProfile, FlowyError> {
  90. if self.is_user_login(&params.email) {
  91. self.get_user_profile().await
  92. } else {
  93. let resp = self.cloud_service.sign_up(params).await?;
  94. let session: Session = resp.clone().into();
  95. let _ = self.set_session(Some(session))?;
  96. let user_table = self.save_user(resp.into()).await?;
  97. let user_profile: UserProfile = user_table.into();
  98. let (ret, mut tx) = mpsc::channel(1);
  99. self.notifier.notify_sign_up(ret, &user_profile);
  100. let _ = tx.recv().await;
  101. Ok(user_profile)
  102. }
  103. }
  104. #[tracing::instrument(level = "debug", skip(self))]
  105. pub async fn sign_out(&self) -> Result<(), FlowyError> {
  106. let session = self.get_session()?;
  107. let _ =
  108. diesel::delete(dsl::user_table.filter(dsl::id.eq(&session.user_id))).execute(&*(self.db_connection()?))?;
  109. let _ = self.database.close_user_db(&session.user_id)?;
  110. let _ = self.set_session(None)?;
  111. self.notifier.notify_logout(&session.token);
  112. let _ = self.sign_out_on_server(&session.token).await?;
  113. Ok(())
  114. }
  115. #[tracing::instrument(level = "debug", skip(self))]
  116. pub async fn update_user_profile(&self, params: UpdateUserProfileParams) -> Result<(), FlowyError> {
  117. let session = self.get_session()?;
  118. let changeset = UserTableChangeset::new(params.clone());
  119. diesel_update_table!(user_table, changeset, &*self.db_connection()?);
  120. let user_profile = self.get_user_profile().await?;
  121. dart_notify(&session.token, UserNotification::UserProfileUpdated)
  122. .payload(user_profile)
  123. .send();
  124. let _ = self.update_user_on_server(&session.token, params).await?;
  125. Ok(())
  126. }
  127. pub async fn init_user(&self) -> Result<(), FlowyError> {
  128. Ok(())
  129. }
  130. pub async fn check_user(&self) -> Result<UserProfile, FlowyError> {
  131. let (user_id, token) = self.get_session()?.into_part();
  132. let user = dsl::user_table
  133. .filter(user_table::id.eq(&user_id))
  134. .first::<UserTable>(&*(self.db_connection()?))?;
  135. let _ = self.read_user_profile_on_server(&token)?;
  136. Ok(user.into())
  137. }
  138. pub async fn get_user_profile(&self) -> Result<UserProfile, FlowyError> {
  139. let (user_id, token) = self.get_session()?.into_part();
  140. let user = dsl::user_table
  141. .filter(user_table::id.eq(&user_id))
  142. .first::<UserTable>(&*(self.db_connection()?))?;
  143. let _ = self.read_user_profile_on_server(&token)?;
  144. Ok(user.into())
  145. }
  146. pub fn user_dir(&self) -> Result<String, FlowyError> {
  147. let session = self.get_session()?;
  148. Ok(format!("{}/{}", self.config.root_dir, session.user_id))
  149. }
  150. pub fn user_id(&self) -> Result<String, FlowyError> {
  151. Ok(self.get_session()?.user_id)
  152. }
  153. pub fn user_name(&self) -> Result<String, FlowyError> {
  154. Ok(self.get_session()?.name)
  155. }
  156. pub fn token(&self) -> Result<String, FlowyError> {
  157. Ok(self.get_session()?.token)
  158. }
  159. }
  160. impl UserSession {
  161. fn read_user_profile_on_server(&self, _token: &str) -> Result<(), FlowyError> {
  162. // let server = self.cloud_service.clone();
  163. // let token = token.to_owned();
  164. // tokio::spawn(async move {
  165. // match server.get_user(&token).await {
  166. // Ok(profile) => {
  167. // dart_notify(&token, UserNotification::UserProfileUpdated)
  168. // .payload(profile)
  169. // .send();
  170. // }
  171. // Err(e) => {
  172. // dart_notify(&token, UserNotification::UserProfileUpdated)
  173. // .error(e)
  174. // .send();
  175. // }
  176. // }
  177. // });
  178. Ok(())
  179. }
  180. async fn update_user_on_server(&self, token: &str, params: UpdateUserProfileParams) -> Result<(), FlowyError> {
  181. let server = self.cloud_service.clone();
  182. let token = token.to_owned();
  183. let _ = tokio::spawn(async move {
  184. match server.update_user(&token, params).await {
  185. Ok(_) => {}
  186. Err(e) => {
  187. // TODO: retry?
  188. log::error!("update user profile failed: {:?}", e);
  189. }
  190. }
  191. })
  192. .await;
  193. Ok(())
  194. }
  195. async fn sign_out_on_server(&self, token: &str) -> Result<(), FlowyError> {
  196. let server = self.cloud_service.clone();
  197. let token = token.to_owned();
  198. let _ = tokio::spawn(async move {
  199. match server.sign_out(&token).await {
  200. Ok(_) => {}
  201. Err(e) => log::error!("Sign out failed: {:?}", e),
  202. }
  203. })
  204. .await;
  205. Ok(())
  206. }
  207. async fn save_user(&self, user: UserTable) -> Result<UserTable, FlowyError> {
  208. let conn = self.db_connection()?;
  209. let _ = diesel::insert_into(user_table::table)
  210. .values(user.clone())
  211. .execute(&*conn)?;
  212. Ok(user)
  213. }
  214. fn set_session(&self, session: Option<Session>) -> Result<(), FlowyError> {
  215. tracing::debug!("Set user session: {:?}", session);
  216. match &session {
  217. None => KV::remove(&self.config.session_cache_key).map_err(|e| FlowyError::new(ErrorCode::Internal, &e))?,
  218. Some(session) => KV::set_str(&self.config.session_cache_key, session.clone().into()),
  219. }
  220. *self.session.write() = session;
  221. Ok(())
  222. }
  223. fn get_session(&self) -> Result<Session, FlowyError> {
  224. let mut session = { (*self.session.read()).clone() };
  225. if session.is_none() {
  226. match KV::get_str(&self.config.session_cache_key) {
  227. None => {}
  228. Some(s) => {
  229. session = Some(Session::from(s));
  230. let _ = self.set_session(session.clone())?;
  231. }
  232. }
  233. }
  234. match session {
  235. None => Err(FlowyError::unauthorized()),
  236. Some(session) => Ok(session),
  237. }
  238. }
  239. fn is_user_login(&self, email: &str) -> bool {
  240. match self.get_session() {
  241. Ok(session) => session.email == email,
  242. Err(_) => false,
  243. }
  244. }
  245. }
  246. pub async fn update_user(
  247. _cloud_service: Arc<dyn UserCloudService>,
  248. pool: Arc<ConnectionPool>,
  249. params: UpdateUserProfileParams,
  250. ) -> Result<(), FlowyError> {
  251. let changeset = UserTableChangeset::new(params);
  252. let conn = pool.get()?;
  253. diesel_update_table!(user_table, changeset, &*conn);
  254. Ok(())
  255. }
  256. impl UserDatabaseConnection for UserSession {
  257. fn get_connection(&self) -> Result<DBConnection, String> {
  258. self.db_connection().map_err(|e| format!("{:?}", e))
  259. }
  260. }
  261. #[derive(Debug, Clone, Default, Serialize, Deserialize)]
  262. struct Session {
  263. user_id: String,
  264. token: String,
  265. email: String,
  266. #[serde(default)]
  267. name: String,
  268. }
  269. impl std::convert::From<SignInResponse> for Session {
  270. fn from(resp: SignInResponse) -> Self {
  271. Session {
  272. user_id: resp.user_id,
  273. token: resp.token,
  274. email: resp.email,
  275. name: resp.name,
  276. }
  277. }
  278. }
  279. impl std::convert::From<SignUpResponse> for Session {
  280. fn from(resp: SignUpResponse) -> Self {
  281. Session {
  282. user_id: resp.user_id,
  283. token: resp.token,
  284. email: resp.email,
  285. name: resp.name,
  286. }
  287. }
  288. }
  289. impl Session {
  290. pub fn into_part(self) -> (String, String) {
  291. (self.user_id, self.token)
  292. }
  293. }
  294. impl std::convert::From<String> for Session {
  295. fn from(s: String) -> Self {
  296. match serde_json::from_str(&s) {
  297. Ok(s) => s,
  298. Err(e) => {
  299. log::error!("Deserialize string to Session failed: {:?}", e);
  300. Session::default()
  301. }
  302. }
  303. }
  304. }
  305. impl std::convert::From<Session> for String {
  306. fn from(session: Session) -> Self {
  307. match serde_json::to_string(&session) {
  308. Ok(s) => s,
  309. Err(e) => {
  310. log::error!("Serialize session to string failed: {:?}", e);
  311. "".to_string()
  312. }
  313. }
  314. }
  315. }