user_session.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. use parking_lot::RwLock;
  2. use serde::{Deserialize, Serialize};
  3. use std::sync::Arc;
  4. use tokio::sync::mpsc;
  5. use backend_service::configuration::ClientServerConfiguration;
  6. use flowy_database::{
  7. kv::KV,
  8. query_dsl::*,
  9. schema::{user_table, user_table::dsl},
  10. DBConnection, ExpressionMethods, UserDatabaseConnection,
  11. };
  12. use flowy_user_data_model::entities::{SignInResponse, SignUpResponse};
  13. use lib_sqlite::ConnectionPool;
  14. use crate::{
  15. entities::{SignInParams, SignUpParams, UpdateUserParams, UserProfile},
  16. errors::{ErrorCode, FlowyError},
  17. notify::*,
  18. services::{
  19. server::{construct_user_server, Server},
  20. user::{database::UserDB, notifier::UserNotifier},
  21. },
  22. sql_tables::{UserTable, UserTableChangeset},
  23. };
  24. pub struct UserSessionConfig {
  25. root_dir: String,
  26. server_config: ClientServerConfiguration,
  27. session_cache_key: String,
  28. }
  29. impl UserSessionConfig {
  30. pub fn new(root_dir: &str, server_config: &ClientServerConfiguration, session_cache_key: &str) -> Self {
  31. Self {
  32. root_dir: root_dir.to_owned(),
  33. server_config: server_config.clone(),
  34. session_cache_key: session_cache_key.to_owned(),
  35. }
  36. }
  37. }
  38. pub struct UserSession {
  39. database: UserDB,
  40. config: UserSessionConfig,
  41. #[allow(dead_code)]
  42. server: Server,
  43. session: RwLock<Option<Session>>,
  44. pub notifier: UserNotifier,
  45. }
  46. impl UserSession {
  47. pub fn new(config: UserSessionConfig) -> Self {
  48. let db = UserDB::new(&config.root_dir);
  49. let server = construct_user_server(&config.server_config);
  50. let notifier = UserNotifier::new();
  51. Self {
  52. database: db,
  53. config,
  54. server,
  55. session: RwLock::new(None),
  56. notifier,
  57. }
  58. }
  59. pub fn init(&self) {
  60. if let Ok(session) = self.get_session() {
  61. self.notifier.notify_login(&session.token, &session.user_id);
  62. }
  63. }
  64. pub fn db_connection(&self) -> Result<DBConnection, FlowyError> {
  65. let user_id = self.get_session()?.user_id;
  66. self.database.get_connection(&user_id)
  67. }
  68. // The caller will be not 'Sync' before of the return value,
  69. // PooledConnection<ConnectionManager> is not sync. You can use
  70. // db_connection_pool function to require the ConnectionPool that is 'Sync'.
  71. //
  72. // let pool = self.db_connection_pool()?;
  73. // let conn: PooledConnection<ConnectionManager> = pool.get()?;
  74. pub fn db_pool(&self) -> Result<Arc<ConnectionPool>, FlowyError> {
  75. let user_id = self.get_session()?.user_id;
  76. self.database.get_pool(&user_id)
  77. }
  78. #[tracing::instrument(level = "debug", skip(self))]
  79. pub async fn sign_in(&self, params: SignInParams) -> Result<UserProfile, FlowyError> {
  80. if self.is_login(&params.email) {
  81. self.user_profile().await
  82. } else {
  83. let resp = self.server.sign_in(params).await?;
  84. let session: Session = resp.clone().into();
  85. let _ = self.set_session(Some(session))?;
  86. let user_table = self.save_user(resp.into()).await?;
  87. let user_profile: UserProfile = user_table.into();
  88. self.notifier.notify_login(&user_profile.token, &user_profile.id);
  89. Ok(user_profile)
  90. }
  91. }
  92. #[tracing::instrument(level = "debug", skip(self))]
  93. pub async fn sign_up(&self, params: SignUpParams) -> Result<UserProfile, FlowyError> {
  94. if self.is_login(&params.email) {
  95. self.user_profile().await
  96. } else {
  97. let resp = self.server.sign_up(params).await?;
  98. let session: Session = resp.clone().into();
  99. let _ = self.set_session(Some(session))?;
  100. let user_table = self.save_user(resp.into()).await?;
  101. let user_profile: UserProfile = user_table.into();
  102. let (ret, mut tx) = mpsc::channel(1);
  103. self.notifier.notify_sign_up(ret, &user_profile);
  104. let _ = tx.recv().await;
  105. Ok(user_profile)
  106. }
  107. }
  108. #[tracing::instrument(level = "debug", skip(self))]
  109. pub async fn sign_out(&self) -> Result<(), FlowyError> {
  110. let session = self.get_session()?;
  111. let _ =
  112. diesel::delete(dsl::user_table.filter(dsl::id.eq(&session.user_id))).execute(&*(self.db_connection()?))?;
  113. let _ = self.database.close_user_db(&session.user_id)?;
  114. let _ = self.set_session(None)?;
  115. self.notifier.notify_logout(&session.token);
  116. let _ = self.sign_out_on_server(&session.token).await?;
  117. Ok(())
  118. }
  119. #[tracing::instrument(level = "debug", skip(self))]
  120. pub async fn update_user(&self, params: UpdateUserParams) -> Result<(), FlowyError> {
  121. let session = self.get_session()?;
  122. let changeset = UserTableChangeset::new(params.clone());
  123. diesel_update_table!(user_table, changeset, &*self.db_connection()?);
  124. let _ = self.update_user_on_server(&session.token, params).await?;
  125. Ok(())
  126. }
  127. pub async fn init_user(&self) -> Result<(), FlowyError> {
  128. Ok(())
  129. }
  130. pub async fn check_user(&self) -> Result<UserProfile, FlowyError> {
  131. let (user_id, token) = self.get_session()?.into_part();
  132. let user = dsl::user_table
  133. .filter(user_table::id.eq(&user_id))
  134. .first::<UserTable>(&*(self.db_connection()?))?;
  135. let _ = self.read_user_profile_on_server(&token)?;
  136. Ok(user.into())
  137. }
  138. pub async fn user_profile(&self) -> Result<UserProfile, FlowyError> {
  139. let (user_id, token) = self.get_session()?.into_part();
  140. let user = dsl::user_table
  141. .filter(user_table::id.eq(&user_id))
  142. .first::<UserTable>(&*(self.db_connection()?))?;
  143. let _ = self.read_user_profile_on_server(&token)?;
  144. Ok(user.into())
  145. }
  146. pub fn user_dir(&self) -> Result<String, FlowyError> {
  147. let session = self.get_session()?;
  148. Ok(format!("{}/{}", self.config.root_dir, session.user_id))
  149. }
  150. pub fn user_id(&self) -> Result<String, FlowyError> {
  151. Ok(self.get_session()?.user_id)
  152. }
  153. pub fn user_name(&self) -> Result<String, FlowyError> {
  154. Ok(self.get_session()?.name)
  155. }
  156. pub fn token(&self) -> Result<String, FlowyError> {
  157. Ok(self.get_session()?.token)
  158. }
  159. }
  160. impl UserSession {
  161. fn read_user_profile_on_server(&self, token: &str) -> Result<(), FlowyError> {
  162. let server = self.server.clone();
  163. let token = token.to_owned();
  164. tokio::spawn(async move {
  165. match server.get_user(&token).await {
  166. Ok(profile) => {
  167. dart_notify(&token, UserNotification::UserProfileUpdated)
  168. .payload(profile)
  169. .send();
  170. }
  171. Err(e) => {
  172. dart_notify(&token, UserNotification::UserProfileUpdated)
  173. .error(e)
  174. .send();
  175. }
  176. }
  177. });
  178. Ok(())
  179. }
  180. async fn update_user_on_server(&self, token: &str, params: UpdateUserParams) -> Result<(), FlowyError> {
  181. let server = self.server.clone();
  182. let token = token.to_owned();
  183. let _ = tokio::spawn(async move {
  184. match server.update_user(&token, params).await {
  185. Ok(_) => {}
  186. Err(e) => {
  187. // TODO: retry?
  188. log::error!("update user profile failed: {:?}", e);
  189. }
  190. }
  191. })
  192. .await;
  193. Ok(())
  194. }
  195. async fn sign_out_on_server(&self, token: &str) -> Result<(), FlowyError> {
  196. let server = self.server.clone();
  197. let token = token.to_owned();
  198. let _ = tokio::spawn(async move {
  199. match server.sign_out(&token).await {
  200. Ok(_) => {}
  201. Err(e) => log::error!("Sign out failed: {:?}", e),
  202. }
  203. })
  204. .await;
  205. Ok(())
  206. }
  207. async fn save_user(&self, user: UserTable) -> Result<UserTable, FlowyError> {
  208. let conn = self.db_connection()?;
  209. let _ = diesel::insert_into(user_table::table)
  210. .values(user.clone())
  211. .execute(&*conn)?;
  212. Ok(user)
  213. }
  214. fn set_session(&self, session: Option<Session>) -> Result<(), FlowyError> {
  215. tracing::debug!("Set user session: {:?}", session);
  216. match &session {
  217. None => KV::remove(&self.config.session_cache_key).map_err(|e| FlowyError::new(ErrorCode::Internal, &e))?,
  218. Some(session) => KV::set_str(&self.config.session_cache_key, session.clone().into()),
  219. }
  220. *self.session.write() = session;
  221. Ok(())
  222. }
  223. fn get_session(&self) -> Result<Session, FlowyError> {
  224. let mut session = { (*self.session.read()).clone() };
  225. if session.is_none() {
  226. match KV::get_str(&self.config.session_cache_key) {
  227. None => {}
  228. Some(s) => {
  229. session = Some(Session::from(s));
  230. let _ = self.set_session(session.clone())?;
  231. }
  232. }
  233. }
  234. match session {
  235. None => Err(FlowyError::unauthorized()),
  236. Some(session) => Ok(session),
  237. }
  238. }
  239. fn is_login(&self, email: &str) -> bool {
  240. match self.get_session() {
  241. Ok(session) => session.email == email,
  242. Err(_) => false,
  243. }
  244. }
  245. }
  246. pub async fn update_user(
  247. _server: Server,
  248. pool: Arc<ConnectionPool>,
  249. params: UpdateUserParams,
  250. ) -> Result<(), FlowyError> {
  251. let changeset = UserTableChangeset::new(params);
  252. let conn = pool.get()?;
  253. diesel_update_table!(user_table, changeset, &*conn);
  254. Ok(())
  255. }
  256. impl UserDatabaseConnection for UserSession {
  257. fn get_connection(&self) -> Result<DBConnection, String> {
  258. self.db_connection().map_err(|e| format!("{:?}", e))
  259. }
  260. }
  261. #[derive(Debug, Clone, Default, Serialize, Deserialize)]
  262. struct Session {
  263. user_id: String,
  264. token: String,
  265. email: String,
  266. #[serde(default)]
  267. name: String,
  268. }
  269. impl std::convert::From<SignInResponse> for Session {
  270. fn from(resp: SignInResponse) -> Self {
  271. Session {
  272. user_id: resp.user_id,
  273. token: resp.token,
  274. email: resp.email,
  275. name: resp.name,
  276. }
  277. }
  278. }
  279. impl std::convert::From<SignUpResponse> for Session {
  280. fn from(resp: SignUpResponse) -> Self {
  281. Session {
  282. user_id: resp.user_id,
  283. token: resp.token,
  284. email: resp.email,
  285. name: resp.name,
  286. }
  287. }
  288. }
  289. impl Session {
  290. pub fn into_part(self) -> (String, String) {
  291. (self.user_id, self.token)
  292. }
  293. }
  294. impl std::convert::From<String> for Session {
  295. fn from(s: String) -> Self {
  296. match serde_json::from_str(&s) {
  297. Ok(s) => s,
  298. Err(e) => {
  299. log::error!("Deserialize string to Session failed: {:?}", e);
  300. Session::default()
  301. }
  302. }
  303. }
  304. }
  305. impl std::convert::From<Session> for String {
  306. fn from(session: Session) -> Self {
  307. match serde_json::to_string(&session) {
  308. Ok(s) => s,
  309. Err(e) => {
  310. log::error!("Serialize session to string failed: {:?}", e);
  311. "".to_string()
  312. }
  313. }
  314. }
  315. }