user.rs 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. use std::str::FromStr;
  2. use std::sync::Arc;
  3. use deadpool_postgres::GenericClient;
  4. use futures::pin_mut;
  5. use futures_util::StreamExt;
  6. use tokio::sync::oneshot::channel;
  7. use tokio_postgres::error::SqlState;
  8. use uuid::Uuid;
  9. use flowy_error::{internal_error, ErrorCode, FlowyError};
  10. use flowy_user::entities::{SignInResponse, SignUpResponse, UpdateUserProfileParams, UserProfile};
  11. use flowy_user::event_map::{UserAuthService, UserCredentials};
  12. use lib_infra::box_any::BoxAny;
  13. use lib_infra::future::FutureResult;
  14. use crate::supabase::entities::{GetUserProfileParams, UserProfileResponse};
  15. use crate::supabase::pg_db::PostgresObject;
  16. use crate::supabase::sql_builder::{SelectSqlBuilder, UpdateSqlBuilder};
  17. use crate::supabase::PostgresServer;
  18. use crate::util::uuid_from_box_any;
  19. pub(crate) const USER_TABLE: &str = "af_user";
  20. pub(crate) const USER_PROFILE_TABLE: &str = "af_user_profile";
  21. pub const USER_UUID: &str = "uuid";
  22. pub struct SupabaseUserAuthServiceImpl {
  23. server: Arc<PostgresServer>,
  24. }
  25. impl SupabaseUserAuthServiceImpl {
  26. pub fn new(server: Arc<PostgresServer>) -> Self {
  27. Self { server }
  28. }
  29. }
  30. impl UserAuthService for SupabaseUserAuthServiceImpl {
  31. fn sign_up(&self, params: BoxAny) -> FutureResult<SignUpResponse, FlowyError> {
  32. let server = self.server.clone();
  33. let (tx, rx) = channel();
  34. tokio::spawn(async move {
  35. tx.send(
  36. async {
  37. let client = server.get_pg_client().await.recv().await?;
  38. let uuid = uuid_from_box_any(params)?;
  39. create_user_with_uuid(&client, uuid).await
  40. }
  41. .await,
  42. )
  43. });
  44. FutureResult::new(async { rx.await.map_err(internal_error)? })
  45. }
  46. fn sign_in(&self, params: BoxAny) -> FutureResult<SignInResponse, FlowyError> {
  47. let server = self.server.clone();
  48. let (tx, rx) = channel();
  49. tokio::spawn(async move {
  50. tx.send(
  51. async {
  52. let client = server.get_pg_client().await.recv().await?;
  53. let uuid = uuid_from_box_any(params)?;
  54. let user_profile = get_user_profile(&client, GetUserProfileParams::Uuid(uuid)).await?;
  55. Ok(SignInResponse {
  56. user_id: user_profile.uid,
  57. workspace_id: user_profile.workspace_id,
  58. ..Default::default()
  59. })
  60. }
  61. .await,
  62. )
  63. });
  64. FutureResult::new(async { rx.await.map_err(internal_error)? })
  65. }
  66. fn sign_out(&self, _token: Option<String>) -> FutureResult<(), FlowyError> {
  67. FutureResult::new(async { Ok(()) })
  68. }
  69. fn update_user(
  70. &self,
  71. _credential: UserCredentials,
  72. params: UpdateUserProfileParams,
  73. ) -> FutureResult<(), FlowyError> {
  74. let server = self.server.clone();
  75. let (tx, rx) = channel();
  76. tokio::spawn(async move {
  77. tx.send(
  78. async move {
  79. let client = server.get_pg_client().await.recv().await?;
  80. update_user_profile(&client, params).await
  81. }
  82. .await,
  83. )
  84. });
  85. FutureResult::new(async { rx.await.map_err(internal_error)? })
  86. }
  87. fn get_user_profile(
  88. &self,
  89. credential: UserCredentials,
  90. ) -> FutureResult<Option<UserProfile>, FlowyError> {
  91. let server = self.server.clone();
  92. let (tx, rx) = channel();
  93. tokio::spawn(async move {
  94. tx.send(
  95. async move {
  96. let client = server.get_pg_client().await.recv().await?;
  97. let uid = credential
  98. .uid
  99. .ok_or(FlowyError::new(ErrorCode::InvalidParams, "uid is required"))?;
  100. let user_profile = get_user_profile(&client, GetUserProfileParams::Uid(uid))
  101. .await
  102. .ok()
  103. .map(|user_profile| UserProfile {
  104. id: user_profile.uid,
  105. email: user_profile.email,
  106. name: user_profile.name,
  107. token: "".to_string(),
  108. icon_url: "".to_string(),
  109. openai_key: "".to_string(),
  110. workspace_id: user_profile.workspace_id,
  111. });
  112. Ok(user_profile)
  113. }
  114. .await,
  115. )
  116. });
  117. FutureResult::new(async { rx.await.map_err(internal_error)? })
  118. }
  119. fn check_user(&self, credential: UserCredentials) -> FutureResult<(), FlowyError> {
  120. let uuid = credential.uuid.and_then(|uuid| Uuid::from_str(&uuid).ok());
  121. let server = self.server.clone();
  122. let (tx, rx) = channel();
  123. tokio::spawn(async move {
  124. tx.send(
  125. async move {
  126. let client = server.get_pg_client().await.recv().await?;
  127. check_user(&client, credential.uid, uuid).await
  128. }
  129. .await,
  130. )
  131. });
  132. FutureResult::new(async { rx.await.map_err(internal_error)? })
  133. }
  134. }
  135. async fn create_user_with_uuid(
  136. client: &PostgresObject,
  137. uuid: Uuid,
  138. ) -> Result<SignUpResponse, FlowyError> {
  139. let mut is_new = true;
  140. if let Err(e) = client
  141. .execute(
  142. &format!("INSERT INTO {} (uuid) VALUES ($1);", USER_TABLE),
  143. &[&uuid],
  144. )
  145. .await
  146. {
  147. if let Some(code) = e.code() {
  148. if code == &SqlState::UNIQUE_VIOLATION {
  149. is_new = false;
  150. } else {
  151. return Err(FlowyError::new(ErrorCode::PgDatabaseError, e));
  152. }
  153. }
  154. };
  155. let user_profile = get_user_profile(client, GetUserProfileParams::Uuid(uuid)).await?;
  156. Ok(SignUpResponse {
  157. user_id: user_profile.uid,
  158. name: user_profile.name,
  159. workspace_id: user_profile.workspace_id,
  160. is_new,
  161. email: Some(user_profile.email),
  162. token: None,
  163. })
  164. }
  165. async fn get_user_profile(
  166. client: &PostgresObject,
  167. params: GetUserProfileParams,
  168. ) -> Result<UserProfileResponse, FlowyError> {
  169. let rows = match params {
  170. GetUserProfileParams::Uid(uid) => {
  171. let stmt = client
  172. .prepare_cached(&format!(
  173. "SELECT * FROM {} WHERE uid = $1",
  174. USER_PROFILE_TABLE
  175. ))
  176. .await
  177. .map_err(|e| FlowyError::new(ErrorCode::PgDatabaseError, e))?;
  178. client
  179. .query(&stmt, &[&uid])
  180. .await
  181. .map_err(|e| FlowyError::new(ErrorCode::PgDatabaseError, e))?
  182. },
  183. GetUserProfileParams::Uuid(uuid) => {
  184. let stmt = client
  185. .prepare_cached(&format!(
  186. "SELECT * FROM {} WHERE uuid = $1",
  187. USER_PROFILE_TABLE
  188. ))
  189. .await
  190. .map_err(|e| FlowyError::new(ErrorCode::PgDatabaseError, e))?;
  191. client
  192. .query(&stmt, &[&uuid])
  193. .await
  194. .map_err(|e| FlowyError::new(ErrorCode::PgDatabaseError, e))?
  195. },
  196. };
  197. let mut user_profiles = rows
  198. .into_iter()
  199. .map(UserProfileResponse::from)
  200. .collect::<Vec<_>>();
  201. if user_profiles.is_empty() {
  202. Err(FlowyError::record_not_found())
  203. } else {
  204. Ok(user_profiles.remove(0))
  205. }
  206. }
  207. async fn update_user_profile(
  208. client: &PostgresObject,
  209. params: UpdateUserProfileParams,
  210. ) -> Result<(), FlowyError> {
  211. if params.is_empty() {
  212. return Err(FlowyError::new(
  213. ErrorCode::InvalidParams,
  214. format!("Update user profile params is empty: {:?}", params),
  215. ));
  216. }
  217. let (sql, pg_params) = UpdateSqlBuilder::new(USER_PROFILE_TABLE)
  218. .set("name", params.name)
  219. .set("email", params.email)
  220. .where_clause("uid", params.id)
  221. .build();
  222. let stmt = client.prepare_cached(&sql).await.map_err(|e| {
  223. FlowyError::new(
  224. ErrorCode::PgDatabaseError,
  225. format!("Prepare update user profile sql error: {}", e),
  226. )
  227. })?;
  228. let affect_rows = client
  229. .execute_raw(&stmt, pg_params)
  230. .await
  231. .map_err(|e| FlowyError::new(ErrorCode::PgDatabaseError, e))?;
  232. tracing::trace!("Update user profile affect rows: {}", affect_rows);
  233. Ok(())
  234. }
  235. async fn check_user(
  236. client: &PostgresObject,
  237. uid: Option<i64>,
  238. uuid: Option<Uuid>,
  239. ) -> Result<(), FlowyError> {
  240. if uid.is_none() && uuid.is_none() {
  241. return Err(FlowyError::new(
  242. ErrorCode::InvalidParams,
  243. "uid and uuid can't be both empty",
  244. ));
  245. }
  246. let (sql, params) = match uid {
  247. None => SelectSqlBuilder::new(USER_TABLE)
  248. .where_clause("uuid", uuid.unwrap())
  249. .build(),
  250. Some(uid) => SelectSqlBuilder::new(USER_TABLE)
  251. .where_clause("uid", uid)
  252. .build(),
  253. };
  254. let stmt = client
  255. .prepare_cached(&sql)
  256. .await
  257. .map_err(|e| FlowyError::new(ErrorCode::PgDatabaseError, e))?;
  258. let rows = Box::pin(
  259. client
  260. .query_raw(&stmt, params)
  261. .await
  262. .map_err(|e| FlowyError::new(ErrorCode::PgDatabaseError, e))?,
  263. );
  264. pin_mut!(rows);
  265. // TODO(nathan): would it be better to use token.
  266. if rows.next().await.is_some() {
  267. Ok(())
  268. } else {
  269. Err(FlowyError::new(
  270. ErrorCode::UserNotExist,
  271. "Can't find the user in pg database",
  272. ))
  273. }
  274. }