user_session.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. use crate::{
  2. dart_notification::*,
  3. errors::{ErrorCode, FlowyError},
  4. event_map::UserCloudService,
  5. services::{
  6. database::{UserDB, UserTable, UserTableChangeset},
  7. notifier::UserNotifier,
  8. },
  9. };
  10. use flowy_database::{
  11. kv::KV,
  12. query_dsl::*,
  13. schema::{user_table, user_table::dsl},
  14. DBConnection, ExpressionMethods, UserDatabaseConnection,
  15. };
  16. use flowy_user_data_model::entities::{
  17. SignInParams, SignInResponse, SignUpParams, SignUpResponse, UpdateUserParams, UserProfile,
  18. };
  19. use lib_sqlite::ConnectionPool;
  20. use parking_lot::RwLock;
  21. use serde::{Deserialize, Serialize};
  22. use std::sync::Arc;
  23. use tokio::sync::mpsc;
  /// Locations a `UserSession` uses to find its persisted state.
  pub struct UserSessionConfig {
      /// Root directory; per-user data lives at `{root_dir}/{user_id}`.
      root_dir: String,
      /// KV-store key under which the serialized session is cached.
      session_cache_key: String,
  }

  impl UserSessionConfig {
      /// Builds a config, taking owned copies of both borrowed strings.
      pub fn new(root_dir: &str, session_cache_key: &str) -> Self {
          let root_dir = String::from(root_dir);
          let session_cache_key = String::from(session_cache_key);
          Self {
              root_dir,
              session_cache_key,
          }
      }
  }
  36. pub struct UserSession {
  37. database: UserDB,
  38. config: UserSessionConfig,
  39. cloud_service: Arc<dyn UserCloudService>,
  40. session: RwLock<Option<Session>>,
  41. pub notifier: UserNotifier,
  42. }
  43. impl UserSession {
  44. pub fn new(config: UserSessionConfig, cloud_service: Arc<dyn UserCloudService>) -> Self {
  45. let db = UserDB::new(&config.root_dir);
  46. let notifier = UserNotifier::new();
  47. Self {
  48. database: db,
  49. config,
  50. cloud_service,
  51. session: RwLock::new(None),
  52. notifier,
  53. }
  54. }
  55. pub fn init(&self) {
  56. if let Ok(session) = self.get_session() {
  57. self.notifier.notify_login(&session.token, &session.user_id);
  58. }
  59. }
  60. pub fn db_connection(&self) -> Result<DBConnection, FlowyError> {
  61. let user_id = self.get_session()?.user_id;
  62. self.database.get_connection(&user_id)
  63. }
  64. // The caller will be not 'Sync' before of the return value,
  65. // PooledConnection<ConnectionManager> is not sync. You can use
  66. // db_connection_pool function to require the ConnectionPool that is 'Sync'.
  67. //
  68. // let pool = self.db_connection_pool()?;
  69. // let conn: PooledConnection<ConnectionManager> = pool.get()?;
  70. pub fn db_pool(&self) -> Result<Arc<ConnectionPool>, FlowyError> {
  71. let user_id = self.get_session()?.user_id;
  72. self.database.get_pool(&user_id)
  73. }
  74. #[tracing::instrument(level = "debug", skip(self))]
  75. pub async fn sign_in(&self, params: SignInParams) -> Result<UserProfile, FlowyError> {
  76. if self.is_user_login(&params.email) {
  77. self.user_profile().await
  78. } else {
  79. let resp = self.cloud_service.sign_in(params).await?;
  80. let session: Session = resp.clone().into();
  81. let _ = self.set_session(Some(session))?;
  82. let user_table = self.save_user(resp.into()).await?;
  83. let user_profile: UserProfile = user_table.into();
  84. self.notifier.notify_login(&user_profile.token, &user_profile.id);
  85. Ok(user_profile)
  86. }
  87. }
  88. #[tracing::instrument(level = "debug", skip(self))]
  89. pub async fn sign_up(&self, params: SignUpParams) -> Result<UserProfile, FlowyError> {
  90. if self.is_user_login(&params.email) {
  91. self.user_profile().await
  92. } else {
  93. let resp = self.cloud_service.sign_up(params).await?;
  94. let session: Session = resp.clone().into();
  95. let _ = self.set_session(Some(session))?;
  96. let user_table = self.save_user(resp.into()).await?;
  97. let user_profile: UserProfile = user_table.into();
  98. let (ret, mut tx) = mpsc::channel(1);
  99. self.notifier.notify_sign_up(ret, &user_profile);
  100. let _ = tx.recv().await;
  101. Ok(user_profile)
  102. }
  103. }
  104. #[tracing::instrument(level = "debug", skip(self))]
  105. pub async fn sign_out(&self) -> Result<(), FlowyError> {
  106. let session = self.get_session()?;
  107. let _ =
  108. diesel::delete(dsl::user_table.filter(dsl::id.eq(&session.user_id))).execute(&*(self.db_connection()?))?;
  109. let _ = self.database.close_user_db(&session.user_id)?;
  110. let _ = self.set_session(None)?;
  111. self.notifier.notify_logout(&session.token);
  112. let _ = self.sign_out_on_server(&session.token).await?;
  113. Ok(())
  114. }
  115. #[tracing::instrument(level = "debug", skip(self))]
  116. pub async fn update_user(&self, params: UpdateUserParams) -> Result<(), FlowyError> {
  117. let session = self.get_session()?;
  118. let changeset = UserTableChangeset::new(params.clone());
  119. diesel_update_table!(user_table, changeset, &*self.db_connection()?);
  120. let _ = self.update_user_on_server(&session.token, params).await?;
  121. Ok(())
  122. }
  123. pub async fn init_user(&self) -> Result<(), FlowyError> {
  124. Ok(())
  125. }
  126. pub async fn check_user(&self) -> Result<UserProfile, FlowyError> {
  127. let (user_id, token) = self.get_session()?.into_part();
  128. let user = dsl::user_table
  129. .filter(user_table::id.eq(&user_id))
  130. .first::<UserTable>(&*(self.db_connection()?))?;
  131. let _ = self.read_user_profile_on_server(&token)?;
  132. Ok(user.into())
  133. }
  134. pub async fn user_profile(&self) -> Result<UserProfile, FlowyError> {
  135. let (user_id, token) = self.get_session()?.into_part();
  136. let user = dsl::user_table
  137. .filter(user_table::id.eq(&user_id))
  138. .first::<UserTable>(&*(self.db_connection()?))?;
  139. let _ = self.read_user_profile_on_server(&token)?;
  140. Ok(user.into())
  141. }
  142. pub fn user_dir(&self) -> Result<String, FlowyError> {
  143. let session = self.get_session()?;
  144. Ok(format!("{}/{}", self.config.root_dir, session.user_id))
  145. }
  146. pub fn user_id(&self) -> Result<String, FlowyError> {
  147. Ok(self.get_session()?.user_id)
  148. }
  149. pub fn user_name(&self) -> Result<String, FlowyError> {
  150. Ok(self.get_session()?.name)
  151. }
  152. pub fn token(&self) -> Result<String, FlowyError> {
  153. Ok(self.get_session()?.token)
  154. }
  155. }
  156. impl UserSession {
  157. fn read_user_profile_on_server(&self, token: &str) -> Result<(), FlowyError> {
  158. let server = self.cloud_service.clone();
  159. let token = token.to_owned();
  160. tokio::spawn(async move {
  161. match server.get_user(&token).await {
  162. Ok(profile) => {
  163. dart_notify(&token, UserNotification::UserProfileUpdated)
  164. .payload(profile)
  165. .send();
  166. }
  167. Err(e) => {
  168. dart_notify(&token, UserNotification::UserProfileUpdated)
  169. .error(e)
  170. .send();
  171. }
  172. }
  173. });
  174. Ok(())
  175. }
  176. async fn update_user_on_server(&self, token: &str, params: UpdateUserParams) -> Result<(), FlowyError> {
  177. let server = self.cloud_service.clone();
  178. let token = token.to_owned();
  179. let _ = tokio::spawn(async move {
  180. match server.update_user(&token, params).await {
  181. Ok(_) => {}
  182. Err(e) => {
  183. // TODO: retry?
  184. log::error!("update user profile failed: {:?}", e);
  185. }
  186. }
  187. })
  188. .await;
  189. Ok(())
  190. }
  191. async fn sign_out_on_server(&self, token: &str) -> Result<(), FlowyError> {
  192. let server = self.cloud_service.clone();
  193. let token = token.to_owned();
  194. let _ = tokio::spawn(async move {
  195. match server.sign_out(&token).await {
  196. Ok(_) => {}
  197. Err(e) => log::error!("Sign out failed: {:?}", e),
  198. }
  199. })
  200. .await;
  201. Ok(())
  202. }
  203. async fn save_user(&self, user: UserTable) -> Result<UserTable, FlowyError> {
  204. let conn = self.db_connection()?;
  205. let _ = diesel::insert_into(user_table::table)
  206. .values(user.clone())
  207. .execute(&*conn)?;
  208. Ok(user)
  209. }
  210. fn set_session(&self, session: Option<Session>) -> Result<(), FlowyError> {
  211. tracing::debug!("Set user session: {:?}", session);
  212. match &session {
  213. None => KV::remove(&self.config.session_cache_key).map_err(|e| FlowyError::new(ErrorCode::Internal, &e))?,
  214. Some(session) => KV::set_str(&self.config.session_cache_key, session.clone().into()),
  215. }
  216. *self.session.write() = session;
  217. Ok(())
  218. }
  219. fn get_session(&self) -> Result<Session, FlowyError> {
  220. let mut session = { (*self.session.read()).clone() };
  221. if session.is_none() {
  222. match KV::get_str(&self.config.session_cache_key) {
  223. None => {}
  224. Some(s) => {
  225. session = Some(Session::from(s));
  226. let _ = self.set_session(session.clone())?;
  227. }
  228. }
  229. }
  230. match session {
  231. None => Err(FlowyError::unauthorized()),
  232. Some(session) => Ok(session),
  233. }
  234. }
  235. fn is_user_login(&self, email: &str) -> bool {
  236. match self.get_session() {
  237. Ok(session) => session.email == email,
  238. Err(_) => false,
  239. }
  240. }
  241. }
  242. pub async fn update_user(
  243. _cloud_service: Arc<dyn UserCloudService>,
  244. pool: Arc<ConnectionPool>,
  245. params: UpdateUserParams,
  246. ) -> Result<(), FlowyError> {
  247. let changeset = UserTableChangeset::new(params);
  248. let conn = pool.get()?;
  249. diesel_update_table!(user_table, changeset, &*conn);
  250. Ok(())
  251. }
  252. impl UserDatabaseConnection for UserSession {
  253. fn get_connection(&self) -> Result<DBConnection, String> {
  254. self.db_connection().map_err(|e| format!("{:?}", e))
  255. }
  256. }
  257. #[derive(Debug, Clone, Default, Serialize, Deserialize)]
  258. struct Session {
  259. user_id: String,
  260. token: String,
  261. email: String,
  262. #[serde(default)]
  263. name: String,
  264. }
  265. impl std::convert::From<SignInResponse> for Session {
  266. fn from(resp: SignInResponse) -> Self {
  267. Session {
  268. user_id: resp.user_id,
  269. token: resp.token,
  270. email: resp.email,
  271. name: resp.name,
  272. }
  273. }
  274. }
  275. impl std::convert::From<SignUpResponse> for Session {
  276. fn from(resp: SignUpResponse) -> Self {
  277. Session {
  278. user_id: resp.user_id,
  279. token: resp.token,
  280. email: resp.email,
  281. name: resp.name,
  282. }
  283. }
  284. }
  285. impl Session {
  286. pub fn into_part(self) -> (String, String) {
  287. (self.user_id, self.token)
  288. }
  289. }
  290. impl std::convert::From<String> for Session {
  291. fn from(s: String) -> Self {
  292. match serde_json::from_str(&s) {
  293. Ok(s) => s,
  294. Err(e) => {
  295. log::error!("Deserialize string to Session failed: {:?}", e);
  296. Session::default()
  297. }
  298. }
  299. }
  300. }
  301. impl std::convert::From<Session> for String {
  302. fn from(session: Session) -> Self {
  303. match serde_json::to_string(&session) {
  304. Ok(s) => s,
  305. Err(e) => {
  306. log::error!("Serialize session to string failed: {:?}", e);
  307. "".to_string()
  308. }
  309. }
  310. }
  311. }