user_session.rs 12 KB


  1. use crate::{
  2. entities::{SignInParams, SignUpParams, UpdateUserParams, UserProfile},
  3. errors::{ErrorCode, UserError},
  4. services::user::database::UserDB,
  5. sql_tables::{UserTable, UserTableChangeset},
  6. };
  7. use crate::{
  8. notify::*,
  9. services::server::{construct_user_server, Server},
  10. };
  11. use flowy_database::{
  12. query_dsl::*,
  13. schema::{user_table, user_table::dsl},
  14. DBConnection,
  15. ExpressionMethods,
  16. UserDatabaseConnection,
  17. };
  18. use flowy_infra::kv::KV;
  19. use flowy_net::config::ServerConfig;
  20. use flowy_sqlite::ConnectionPool;
  21. use flowy_ws::{WsController, WsMessageHandler, WsState};
  22. use parking_lot::RwLock;
  23. use serde::{Deserialize, Serialize};
  24. use std::sync::Arc;
  25. pub struct UserSessionConfig {
  26. root_dir: String,
  27. server_config: ServerConfig,
  28. }
  29. impl UserSessionConfig {
  30. pub fn new(root_dir: &str, server_config: &ServerConfig) -> Self {
  31. Self {
  32. root_dir: root_dir.to_owned(),
  33. server_config: server_config.clone(),
  34. }
  35. }
  36. }
  /// Session life-cycle events reported through [`SessionStatusCallback`].
  pub enum SessionStatus {
      /// Emitted after a successful `sign_in` / `sign_up`; carries the new token.
      Login {
          token: String,
      },
      /// Emitted by `sign_out`; carries the token of the session that was ended.
      Expired {
          token: String,
      },
  }
  41. pub type SessionStatusCallback = Arc<dyn Fn(SessionStatus) + Send + Sync>;
  42. pub struct UserSession {
  43. database: UserDB,
  44. config: UserSessionConfig,
  45. #[allow(dead_code)]
  46. server: Server,
  47. session: RwLock<Option<Session>>,
  48. pub ws_controller: Arc<WsController>,
  49. status_callback: SessionStatusCallback,
  50. }
  51. impl UserSession {
  52. pub fn new(config: UserSessionConfig, status_callback: SessionStatusCallback) -> Self {
  53. let db = UserDB::new(&config.root_dir);
  54. let server = construct_user_server(&config.server_config);
  55. let ws_controller = Arc::new(WsController::new());
  56. let user_session = Self {
  57. database: db,
  58. config,
  59. server,
  60. session: RwLock::new(None),
  61. ws_controller,
  62. status_callback,
  63. };
  64. user_session
  65. }
  66. pub fn db_connection(&self) -> Result<DBConnection, UserError> {
  67. let user_id = self.get_session()?.user_id;
  68. self.database.get_connection(&user_id)
  69. }
  70. // The caller will be not 'Sync' before of the return value,
  71. // PooledConnection<ConnectionManager> is not sync. You can use
  72. // db_connection_pool function to require the ConnectionPool that is 'Sync'.
  73. //
  74. // let pool = self.db_connection_pool()?;
  75. // let conn: PooledConnection<ConnectionManager> = pool.get()?;
  76. pub fn db_pool(&self) -> Result<Arc<ConnectionPool>, UserError> {
  77. let user_id = self.get_session()?.user_id;
  78. self.database.get_pool(&user_id)
  79. }
  80. #[tracing::instrument(level = "debug", skip(self))]
  81. pub async fn sign_in(&self, params: SignInParams) -> Result<UserProfile, UserError> {
  82. if self.is_login(&params.email) {
  83. self.user_profile().await
  84. } else {
  85. let resp = self.server.sign_in(params).await?;
  86. let session = Session::new(&resp.user_id, &resp.token, &resp.email);
  87. let _ = self.set_session(Some(session))?;
  88. let user_table = self.save_user(resp.into()).await?;
  89. let user_profile = UserProfile::from(user_table);
  90. (self.status_callback)(SessionStatus::Login {
  91. token: user_profile.token.clone(),
  92. });
  93. Ok(user_profile)
  94. }
  95. }
  96. #[tracing::instrument(level = "debug", skip(self))]
  97. pub async fn sign_up(&self, params: SignUpParams) -> Result<UserProfile, UserError> {
  98. if self.is_login(&params.email) {
  99. self.user_profile().await
  100. } else {
  101. let resp = self.server.sign_up(params).await?;
  102. let session = Session::new(&resp.user_id, &resp.token, &resp.email);
  103. let _ = self.set_session(Some(session))?;
  104. let user_table = self.save_user(resp.into()).await?;
  105. let user_profile = UserProfile::from(user_table);
  106. (self.status_callback)(SessionStatus::Login {
  107. token: user_profile.token.clone(),
  108. });
  109. Ok(user_profile)
  110. }
  111. }
  112. #[tracing::instrument(level = "debug", skip(self))]
  113. pub async fn sign_out(&self) -> Result<(), UserError> {
  114. let session = self.get_session()?;
  115. let _ =
  116. diesel::delete(dsl::user_table.filter(dsl::id.eq(&session.user_id))).execute(&*(self.db_connection()?))?;
  117. let _ = self.database.close_user_db(&session.user_id)?;
  118. let _ = self.set_session(None)?;
  119. (self.status_callback)(SessionStatus::Expired {
  120. token: session.token.clone(),
  121. });
  122. let _ = self.sign_out_on_server(&session.token).await?;
  123. Ok(())
  124. }
  125. #[tracing::instrument(level = "debug", skip(self))]
  126. pub async fn update_user(&self, params: UpdateUserParams) -> Result<(), UserError> {
  127. let session = self.get_session()?;
  128. let changeset = UserTableChangeset::new(params.clone());
  129. diesel_update_table!(user_table, changeset, &*self.db_connection()?);
  130. let _ = self.update_user_on_server(&session.token, params).await?;
  131. Ok(())
  132. }
  133. pub async fn init_user(&self) -> Result<(), UserError> {
  134. let (_, token) = self.get_session()?.into_part();
  135. let _ = self.start_ws_connection(&token).await?;
  136. Ok(())
  137. }
  138. pub async fn check_user(&self) -> Result<UserProfile, UserError> {
  139. let (user_id, token) = self.get_session()?.into_part();
  140. let user = dsl::user_table
  141. .filter(user_table::id.eq(&user_id))
  142. .first::<UserTable>(&*(self.db_connection()?))?;
  143. let _ = self.read_user_profile_on_server(&token)?;
  144. Ok(UserProfile::from(user))
  145. }
  146. pub async fn user_profile(&self) -> Result<UserProfile, UserError> {
  147. let (user_id, token) = self.get_session()?.into_part();
  148. let user = dsl::user_table
  149. .filter(user_table::id.eq(&user_id))
  150. .first::<UserTable>(&*(self.db_connection()?))?;
  151. let _ = self.read_user_profile_on_server(&token)?;
  152. Ok(UserProfile::from(user))
  153. }
  154. pub fn user_dir(&self) -> Result<String, UserError> {
  155. let session = self.get_session()?;
  156. Ok(format!("{}/{}", self.config.root_dir, session.user_id))
  157. }
  158. pub fn user_id(&self) -> Result<String, UserError> { Ok(self.get_session()?.user_id) }
  159. pub fn token(&self) -> Result<String, UserError> { Ok(self.get_session()?.token) }
  160. pub fn add_ws_handler(&self, handler: Arc<dyn WsMessageHandler>) {
  161. let _ = self.ws_controller.add_handler(handler);
  162. }
  163. }
  164. impl UserSession {
  165. fn read_user_profile_on_server(&self, token: &str) -> Result<(), UserError> {
  166. let server = self.server.clone();
  167. let token = token.to_owned();
  168. tokio::spawn(async move {
  169. match server.get_user(&token).await {
  170. Ok(profile) => {
  171. dart_notify(&token, UserObservable::UserProfileUpdated)
  172. .payload(profile)
  173. .send();
  174. },
  175. Err(e) => {
  176. dart_notify(&token, UserObservable::UserProfileUpdated).error(e).send();
  177. },
  178. }
  179. });
  180. Ok(())
  181. }
  182. async fn update_user_on_server(&self, token: &str, params: UpdateUserParams) -> Result<(), UserError> {
  183. let server = self.server.clone();
  184. let token = token.to_owned();
  185. let _ = tokio::spawn(async move {
  186. match server.update_user(&token, params).await {
  187. Ok(_) => {},
  188. Err(e) => {
  189. // TODO: retry?
  190. log::error!("update user profile failed: {:?}", e);
  191. },
  192. }
  193. })
  194. .await;
  195. Ok(())
  196. }
  197. async fn sign_out_on_server(&self, token: &str) -> Result<(), UserError> {
  198. let server = self.server.clone();
  199. let token = token.to_owned();
  200. let _ = tokio::spawn(async move {
  201. match server.sign_out(&token).await {
  202. Ok(_) => {},
  203. Err(e) => log::error!("Sign out failed: {:?}", e),
  204. }
  205. })
  206. .await;
  207. Ok(())
  208. }
  209. async fn save_user(&self, user: UserTable) -> Result<UserTable, UserError> {
  210. let conn = self.db_connection()?;
  211. let _ = diesel::insert_into(user_table::table)
  212. .values(user.clone())
  213. .execute(&*conn)?;
  214. Ok(user)
  215. }
  216. fn set_session(&self, session: Option<Session>) -> Result<(), UserError> {
  217. log::debug!("Set user session: {:?}", session);
  218. match &session {
  219. None => KV::remove(SESSION_CACHE_KEY).map_err(|e| UserError::new(ErrorCode::InternalError, &e))?,
  220. Some(session) => KV::set_str(SESSION_CACHE_KEY, session.clone().into()),
  221. }
  222. *self.session.write() = session;
  223. Ok(())
  224. }
  225. fn get_session(&self) -> Result<Session, UserError> {
  226. let mut session = { (*self.session.read()).clone() };
  227. if session.is_none() {
  228. match KV::get_str(SESSION_CACHE_KEY) {
  229. None => {},
  230. Some(s) => {
  231. session = Some(Session::from(s));
  232. let _ = self.set_session(session.clone())?;
  233. },
  234. }
  235. }
  236. match session {
  237. None => Err(UserError::unauthorized()),
  238. Some(session) => Ok(session),
  239. }
  240. }
  241. fn is_login(&self, email: &str) -> bool {
  242. match self.get_session() {
  243. Ok(session) => session.email == email,
  244. Err(_) => false,
  245. }
  246. }
  247. #[tracing::instrument(level = "debug", skip(self, token))]
  248. pub async fn start_ws_connection(&self, token: &str) -> Result<(), UserError> {
  249. let addr = format!("{}/{}", self.server.ws_addr(), token);
  250. self.listen_on_websocket();
  251. let _ = self.ws_controller.start_connect(addr).await?;
  252. Ok(())
  253. }
  254. #[tracing::instrument(level = "debug", skip(self))]
  255. fn listen_on_websocket(&self) {
  256. let mut notify = self.ws_controller.state_subscribe();
  257. let ws_controller = self.ws_controller.clone();
  258. let _ = tokio::spawn(async move {
  259. loop {
  260. match notify.recv().await {
  261. Ok(state) => {
  262. log::info!("Websocket state changed: {}", state);
  263. match state {
  264. WsState::Init => {},
  265. WsState::Connected(_) => {},
  266. WsState::Disconnected(_) => {
  267. match ws_controller.retry().await {
  268. Ok(_) => {},
  269. Err(e) => {
  270. log::error!("Retry websocket connect failed: {:?}", e);
  271. }
  272. }
  273. },
  274. }
  275. },
  276. Err(e) => {
  277. log::error!("Websocket state notify error: {:?}", e);
  278. break;
  279. },
  280. }
  281. }
  282. });
  283. }
  284. }
  285. pub async fn update_user(
  286. _server: Server,
  287. pool: Arc<ConnectionPool>,
  288. params: UpdateUserParams,
  289. ) -> Result<(), UserError> {
  290. let changeset = UserTableChangeset::new(params);
  291. let conn = pool.get()?;
  292. diesel_update_table!(user_table, changeset, &*conn);
  293. Ok(())
  294. }
  295. impl UserDatabaseConnection for UserSession {
  296. fn get_connection(&self) -> Result<DBConnection, String> { self.db_connection().map_err(|e| format!("{:?}", e)) }
  297. }
  /// Key under which the serialized session is kept in the key-value store.
  const SESSION_CACHE_KEY: &str = "session_cache_key";
  299. #[derive(Debug, Clone, Default, Serialize, Deserialize)]
  300. struct Session {
  301. user_id: String,
  302. token: String,
  303. email: String,
  304. }
  305. impl Session {
  306. pub fn new(user_id: &str, token: &str, email: &str) -> Self {
  307. Self {
  308. user_id: user_id.to_owned(),
  309. token: token.to_owned(),
  310. email: email.to_owned(),
  311. }
  312. }
  313. pub fn into_part(self) -> (String, String) { (self.user_id, self.token) }
  314. }
  315. impl std::convert::From<String> for Session {
  316. fn from(s: String) -> Self {
  317. match serde_json::from_str(&s) {
  318. Ok(s) => s,
  319. Err(e) => {
  320. log::error!("Deserialize string to Session failed: {:?}", e);
  321. Session::default()
  322. },
  323. }
  324. }
  325. }
  326. impl std::convert::Into<String> for Session {
  327. fn into(self) -> String {
  328. match serde_json::to_string(&self) {
  329. Ok(s) => s,
  330. Err(e) => {
  331. log::error!("Serialize session to string failed: {:?}", e);
  332. "".to_string()
  333. },
  334. }
  335. }
  336. }