appflowy 4 éve
szülő
commit
8be53b323e

+ 2 - 0
backend/Cargo.toml

@@ -29,6 +29,8 @@ protobuf = {version = "2.20.0"}
 uuid = { version = "0.8", features = ["serde", "v4"] }
 config = { version = "0.10.1", default-features = false, features = ["yaml"] }
 chrono = "0.4.19"
+anyhow = "1.0.40"
+thiserror = "1.0.24"
 
 flowy-log = { path = "../rust-lib/flowy-log" }
 flowy-user = { path = "../rust-lib/flowy-user" }

+ 18 - 27
backend/src/application.rs

@@ -2,7 +2,6 @@ use crate::{
     config::{get_configuration, DatabaseSettings, Settings},
     context::AppContext,
     routers::*,
-    user_service::Auth,
     ws_service::WSServer,
 };
 use actix::Actor;
@@ -13,39 +12,37 @@ use std::{net::TcpListener, sync::Arc};
 pub struct Application {
     port: u16,
     server: Server,
-    app_ctx: Arc<AppContext>,
 }
 
 impl Application {
     pub async fn build(configuration: Settings) -> Result<Self, std::io::Error> {
-        let app_ctx = init_app_context(&configuration).await;
         let address = format!(
             "{}:{}",
             configuration.application.host, configuration.application.port
         );
         let listener = TcpListener::bind(&address)?;
         let port = listener.local_addr().unwrap().port();
-        let server = run(listener, app_ctx.clone())?;
-        Ok(Self {
-            port,
-            server,
-            app_ctx,
-        })
+        let app_ctx = init_app_context(&configuration).await;
+        let server = run(listener, app_ctx)?;
+        Ok(Self { port, server })
     }
 
     pub async fn run_until_stopped(self) -> Result<(), std::io::Error> { self.server.await }
 }
 
-pub fn run(listener: TcpListener, app_ctx: Arc<AppContext>) -> Result<Server, std::io::Error> {
+pub fn run(listener: TcpListener, app_ctx: AppContext) -> Result<Server, std::io::Error> {
+    let AppContext { ws_server, pg_pool } = app_ctx;
+    let ws_server = Data::new(ws_server);
+    let pg_pool = Data::new(pg_pool);
+
     let server = HttpServer::new(move || {
         App::new()
             .wrap(middleware::Logger::default())
             .app_data(web::JsonConfig::default().limit(4096))
             .service(ws_scope())
             .service(user_scope())
-            .app_data(Data::new(app_ctx.ws_server.clone()))
-            .app_data(Data::new(app_ctx.db_pool.clone()))
-            .app_data(Data::new(app_ctx.auth.clone()))
+            .app_data(ws_server.clone())
+            .app_data(pg_pool.clone())
     })
     .listen(listener)?
     .run();
@@ -58,24 +55,18 @@ fn user_scope() -> Scope {
     web::scope("/user").service(web::resource("/register").route(web::post().to(user::register)))
 }
 
-async fn init_app_context(configuration: &Settings) -> Arc<AppContext> {
+async fn init_app_context(configuration: &Settings) -> AppContext {
     let _ = flowy_log::Builder::new("flowy").env_filter("Debug").build();
-    let pg_pool = Arc::new(
-        get_connection_pool(&configuration.database)
-            .await
-            .expect(&format!(
-                "Failed to connect to Postgres {:?}.",
-                configuration.database
-            )),
-    );
+    let pg_pool = get_connection_pool(&configuration.database)
+        .await
+        .expect(&format!(
+            "Failed to connect to Postgres at {:?}.",
+            configuration.database
+        ));
 
     let ws_server = WSServer::new().start();
 
-    let auth = Arc::new(Auth::new(pg_pool.clone()));
-
-    let ctx = AppContext::new(ws_server, pg_pool, auth);
-
-    Arc::new(ctx)
+    AppContext::new(ws_server, pg_pool)
 }
 
 pub async fn get_connection_pool(configuration: &DatabaseSettings) -> Result<PgPool, sqlx::Error> {

+ 4 - 6
backend/src/context.rs

@@ -1,4 +1,4 @@
-use crate::{user_service::Auth, ws_service::WSServer};
+use crate::ws_service::WSServer;
 use actix::Addr;
 
 use sqlx::PgPool;
@@ -6,16 +6,14 @@ use std::sync::Arc;
 
 pub struct AppContext {
     pub ws_server: Addr<WSServer>,
-    pub db_pool: Arc<PgPool>,
-    pub auth: Arc<Auth>,
+    pub pg_pool: PgPool,
 }
 
 impl AppContext {
-    pub fn new(ws_server: Addr<WSServer>, db_pool: Arc<PgPool>, auth: Arc<Auth>) -> Self {
+    pub fn new(ws_server: Addr<WSServer>, db_pool: PgPool) -> Self {
         AppContext {
             ws_server,
-            db_pool,
-            auth,
+            pg_pool: db_pool,
         }
     }
 }

+ 2 - 5
backend/src/routers/helper.rs

@@ -20,14 +20,11 @@ pub fn parse_from_bytes<T: Message>(bytes: &[u8]) -> Result<T, ServerError> {
 pub async fn poll_payload(mut payload: web::Payload) -> Result<web::BytesMut, ServerError> {
     let mut body = web::BytesMut::new();
     while let Some(chunk) = payload.next().await {
-        let chunk = chunk.map_err(|e| ServerError {
-            code: ServerCode::InternalError,
-            msg: format!("{:?}", e),
-        })?;
+        let chunk = chunk.map_err(ServerError::internal)?;
 
         if (body.len() + chunk.len()) > MAX_PAYLOAD_SIZE {
             return Err(ServerError {
-                code: ServerCode::PayloadOverflow,
+                code: Code::PayloadOverflow,
                 msg: "Payload overflow".to_string(),
             });
         }

+ 5 - 3
backend/src/routers/user.rs

@@ -1,4 +1,4 @@
-use crate::{routers::helper::parse_from_payload, user_service::Auth};
+use crate::routers::helper::parse_from_payload;
 use actix_web::{
     web::{Data, Payload},
     Error,
@@ -8,15 +8,17 @@ use actix_web::{
 use flowy_net::response::*;
 use flowy_user::protobuf::SignUpParams;
 
+use crate::user_service::sign_up;
+use sqlx::PgPool;
 use std::sync::Arc;
 
 pub async fn register(
     _request: HttpRequest,
     payload: Payload,
-    auth: Data<Arc<Auth>>,
+    pool: Data<PgPool>,
 ) -> Result<HttpResponse, ServerError> {
     let params: SignUpParams = parse_from_payload(payload).await?;
-    let resp = auth.sign_up(params).await?;
+    let resp = sign_up(pool.get_ref(), params).await?;
 
     Ok(resp.into())
 }

+ 61 - 32
backend/src/user_service/auth.rs

@@ -1,44 +1,73 @@
+use anyhow::Context;
 use chrono::Utc;
-use flowy_net::response::{FlowyResponse, ServerCode, ServerError};
+use flowy_net::response::{Code, FlowyResponse, ServerError};
 use flowy_user::{entities::SignUpResponse, protobuf::SignUpParams};
-use sqlx::PgPool;
+use sqlx::{Error, PgPool, Postgres, Transaction};
 use std::sync::Arc;
 
-pub struct Auth {
-    db_pool: Arc<PgPool>,
+pub async fn sign_up(pool: &PgPool, params: SignUpParams) -> Result<FlowyResponse, ServerError> {
+    let mut transaction = pool
+        .begin()
+        .await
+        .context("Failed to acquire a Postgres connection from the pool")?;
+
+    let _ = is_email_exist(&mut transaction, &params.email).await?;
+
+    let data = insert_user(&mut transaction, params)
+        .await
+        .context("Failed to insert user")?;
+
+    let response = FlowyResponse::success(data).context("Failed to generate response")?;
+
+    Ok(response)
 }
 
-impl Auth {
-    pub fn new(db_pool: Arc<PgPool>) -> Self { Self { db_pool } }
+async fn is_email_exist(
+    transaction: &mut Transaction<'_, Postgres>,
+    email: &str,
+) -> Result<(), ServerError> {
+    let result = sqlx::query!(
+        r#"SELECT email FROM user_table WHERE email = $1"#,
+        email.to_string()
+    )
+    .fetch_optional(transaction)
+    .await
+    .map_err(ServerError::internal)?;
 
-    pub async fn sign_up(&self, params: SignUpParams) -> Result<FlowyResponse, ServerError> {
-        // email exist?
-        // generate user id
-        let uuid = uuid::Uuid::new_v4();
-        let result = sqlx::query!(
-            r#"
+    match result {
+        Some(_) => Err(ServerError {
+            code: Code::EmailAlreadyExists,
+            msg: format!("{} already exists", email),
+        }),
+        None => Ok(()),
+    }
+}
+
+async fn insert_user(
+    transaction: &mut Transaction<'_, Postgres>,
+    params: SignUpParams,
+) -> Result<SignUpResponse, ServerError> {
+    let uuid = uuid::Uuid::new_v4();
+    let result = sqlx::query!(
+        r#"
             INSERT INTO user_table (id, email, name, create_time, password)
             VALUES ($1, $2, $3, $4, $5)
         "#,
-            uuid,
-            params.email,
-            params.name,
-            Utc::now(),
-            "123".to_string()
-        )
-        .execute(self.db_pool.as_ref())
-        .await;
-
-        let data = SignUpResponse {
-            uid: uuid.to_string(),
-            name: params.name,
-            email: params.email,
-        };
-
-        let response = FlowyResponse::from(data, "", ServerCode::Success)?;
-
-        Ok(response)
-    }
+        uuid,
+        params.email,
+        params.name,
+        Utc::now(),
+        "123".to_string()
+    )
+    .execute(transaction)
+    .await
+    .map_err(ServerError::internal)?;
+
+    let data = SignUpResponse {
+        uid: uuid.to_string(),
+        name: params.name,
+        email: params.email,
+    };
 
-    pub fn is_email_exist(&self, email: &str) -> bool { true }
+    Ok(data)
 }

+ 2 - 1
rust-lib/flowy-net/Cargo.toml

@@ -21,7 +21,8 @@ tokio = { version = "1", features = ["full"] }
 actix-web = {version = "4.0.0-beta.8", optional = true}
 derive_more = {version = "0.99", features = ["display"]}
 flowy-derive = { path = "../flowy-derive" }
-
+anyhow = "1.0"
+thiserror = "1.0.24"
 
 [features]
 http = ["actix-web"]

+ 13 - 34
rust-lib/flowy-net/src/request/request.rs

@@ -1,4 +1,4 @@
-use crate::response::{FlowyResponse, ServerCode, ServerError};
+use crate::response::{Code, FlowyResponse, ServerError};
 use bytes::Bytes;
 use hyper::http;
 use protobuf::{Message, ProtobufError};
@@ -25,41 +25,20 @@ where
     });
 
     let response = rx.await??;
-    if response.status() == http::StatusCode::OK {
-        let response_bytes = response.bytes().await?;
-        let flowy_resp: FlowyResponse = serde_json::from_slice(&response_bytes).unwrap();
-        let data = T2::try_from(flowy_resp.data)?;
-        Ok(data)
-    } else {
-        Err(ServerError {
-            code: ServerCode::InternalError,
-            msg: format!("{:?}", response),
-        })
-    }
-}
-
-async fn parse_response<T>(response: Response) -> Result<T, ServerError>
-where
-    T: Message,
-{
-    let bytes = response.bytes().await?;
-    parse_bytes(bytes)
+    let data = get_response_data(response).await?;
+    Ok(T2::try_from(data)?)
 }
 
-fn parse_bytes<T>(bytes: Bytes) -> Result<T, ServerError>
-where
-    T: Message,
-{
-    match Message::parse_from_bytes(&bytes) {
-        Ok(data) => Ok(data),
-        Err(e) => {
-            log::error!(
-                "Parse bytes for {:?} failed: {}",
-                std::any::type_name::<T>(),
-                e
-            );
-            Err(e.into())
-        },
+async fn get_response_data(original: Response) -> Result<Bytes, ServerError> {
+    if original.status() == http::StatusCode::OK {
+        let bytes = original.bytes().await?;
+        let response: FlowyResponse = serde_json::from_slice(&bytes)?;
+        match response.error {
+            None => Ok(response.data),
+            Some(error) => Err(error),
+        }
+    } else {
+        Err(ServerError::http(original))
     }
 }
 

+ 59 - 46
rust-lib/flowy-net/src/response/response.rs

@@ -1,15 +1,30 @@
 use bytes::Bytes;
 use serde::{Deserialize, Serialize, __private::Formatter};
 use serde_repr::*;
-use std::{convert::TryInto, error::Error, fmt};
+use std::{convert::TryInto, error::Error, fmt, fmt::Debug};
 use tokio::sync::oneshot::error::RecvError;
 
-#[derive(Debug)]
+#[derive(thiserror::Error, Debug, Serialize, Deserialize, Clone)]
 pub struct ServerError {
-    pub code: ServerCode,
+    pub code: Code,
     pub msg: String,
 }
 
+macro_rules! static_error {
+    ($name:ident, $status:expr) => {
+        #[allow(non_snake_case, missing_docs)]
+        pub fn $name<T: Debug>(error: T) -> ServerError {
+            let msg = format!("{:?}", error);
+            ServerError { code: $status, msg }
+        }
+    };
+}
+
+impl ServerError {
+    static_error!(internal, Code::InternalError);
+    static_error!(http, Code::HttpError);
+}
+
 impl std::fmt::Display for ServerError {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
         let msg = format!("{:?}:{}", self.code, self.msg);
@@ -20,89 +35,87 @@ impl std::fmt::Display for ServerError {
 impl std::convert::From<&ServerError> for FlowyResponse {
     fn from(error: &ServerError) -> Self {
         FlowyResponse {
-            msg: error.msg.clone(),
             data: Bytes::from(vec![]),
-            code: error.code.clone(),
+            error: Some(error.clone()),
         }
     }
 }
 
 #[derive(Serialize_repr, Deserialize_repr, PartialEq, Debug, Clone)]
 #[repr(u16)]
-pub enum ServerCode {
-    Success          = 0,
-    InvalidToken     = 1,
-    InternalError    = 2,
-    Unauthorized     = 3,
-    PayloadOverflow  = 4,
-    PayloadSerdeFail = 5,
-    ProtobufError    = 6,
-    SerdeError       = 7,
-    ConnectRefused   = 8,
-    ConnectTimeout   = 9,
-    ConnectClose     = 10,
-    ConnectCancel    = 11,
+pub enum Code {
+    InvalidToken       = 1,
+    Unauthorized       = 3,
+    PayloadOverflow    = 4,
+    PayloadSerdeFail   = 5,
+
+    ProtobufError      = 6,
+    SerdeError         = 7,
+
+    EmailAlreadyExists = 50,
+
+    ConnectRefused     = 100,
+    ConnectTimeout     = 101,
+    ConnectClose       = 102,
+    ConnectCancel      = 103,
+
+    SqlError           = 200,
+
+    HttpError          = 300,
+
+    InternalError      = 1000,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
 pub struct FlowyResponse {
-    pub msg: String,
     pub data: Bytes,
-    pub code: ServerCode,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<ServerError>,
 }
 
 impl FlowyResponse {
-    pub fn new(data: Bytes, msg: &str, code: ServerCode) -> Self {
-        FlowyResponse {
-            msg: msg.to_owned(),
-            data,
-            code,
-        }
-    }
+    pub fn new(data: Bytes, error: Option<ServerError>) -> Self { FlowyResponse { data, error } }
 
-    pub fn from<T: TryInto<Bytes, Error = protobuf::ProtobufError>>(
+    pub fn success<T: TryInto<Bytes, Error = protobuf::ProtobufError>>(
         data: T,
-        msg: &str,
-        code: ServerCode,
     ) -> Result<Self, ServerError> {
         let bytes: Bytes = data.try_into()?;
-        Ok(Self::new(bytes, msg, code))
+        Ok(Self::new(bytes, None))
     }
 }
 
 impl std::convert::From<protobuf::ProtobufError> for ServerError {
     fn from(err: protobuf::ProtobufError) -> Self {
         ServerError {
-            code: ServerCode::ProtobufError,
+            code: Code::ProtobufError,
             msg: format!("{}", err),
         }
     }
 }
 
 impl std::convert::From<RecvError> for ServerError {
-    fn from(error: RecvError) -> Self {
-        ServerError {
-            code: ServerCode::InternalError,
-            msg: format!("{:?}", error),
-        }
-    }
+    fn from(error: RecvError) -> Self { ServerError::internal(error) }
 }
 
 impl std::convert::From<serde_json::Error> for ServerError {
     fn from(e: serde_json::Error) -> Self {
         let msg = format!("Serial error: {:?}", e);
         ServerError {
-            code: ServerCode::SerdeError,
+            code: Code::SerdeError,
             msg,
         }
     }
 }
 
+impl std::convert::From<anyhow::Error> for ServerError {
+    fn from(error: anyhow::Error) -> Self { ServerError::internal(error) }
+}
+
 impl std::convert::From<reqwest::Error> for ServerError {
     fn from(error: reqwest::Error) -> Self {
         if error.is_timeout() {
             return ServerError {
-                code: ServerCode::ConnectTimeout,
+                code: Code::ConnectTimeout,
                 msg: format!("{}", error),
             };
         }
@@ -111,22 +124,22 @@ impl std::convert::From<reqwest::Error> for ServerError {
             let hyper_error: Option<&hyper::Error> = error.source().unwrap().downcast_ref();
             return match hyper_error {
                 None => ServerError {
-                    code: ServerCode::ConnectRefused,
+                    code: Code::ConnectRefused,
                     msg: format!("{:?}", error),
                 },
                 Some(hyper_error) => {
-                    let mut code = ServerCode::InternalError;
+                    let mut code = Code::InternalError;
                     let msg = format!("{}", error);
                     if hyper_error.is_closed() {
-                        code = ServerCode::ConnectClose;
+                        code = Code::ConnectClose;
                     }
 
                     if hyper_error.is_connect() {
-                        code = ServerCode::ConnectRefused;
+                        code = Code::ConnectRefused;
                     }
 
                     if hyper_error.is_canceled() {
-                        code = ServerCode::ConnectCancel;
+                        code = Code::ConnectCancel;
                     }
 
                     if hyper_error.is_timeout() {}
@@ -138,7 +151,7 @@ impl std::convert::From<reqwest::Error> for ServerError {
 
         let msg = format!("{:?}", error);
         ServerError {
-            code: ServerCode::ProtobufError,
+            code: Code::ProtobufError,
             msg,
         }
     }

+ 1 - 1
rust-lib/flowy-user/tests/server/user_test.rs

@@ -7,7 +7,7 @@ async fn user_register_test() {
     let params = SignUpParams {
         email: "[email protected]".to_string(),
         name: "annie".to_string(),
-        password: "123".to_string(),
+        password: "1233333".to_string(),
     };
     let result = server.sign_up(params).await.unwrap();
     println!("{:?}", result);