Ver Fonte

add user sign up handler

appflowy há 3 anos atrás
pai
commit
89c5d5468e

+ 3 - 1
backend/Cargo.toml

@@ -12,7 +12,6 @@ actix-http = "2.2.1"
 actix-web-actors = "3"
 actix-codec = "0.3"
 
-
 futures = "0.3.15"
 bytes = "0.5"
 toml = "0.5.8"
@@ -21,7 +20,10 @@ log = "0.4.14"
 serde_json = "1.0"
 serde = { version = "1.0", features = ["derive"] }
 serde_repr = "0.1"
+derive_more = {version = "0.99", features = ["display"]}
+protobuf = {version = "2.20.0"}
 flowy-log = { path = "../rust-lib/flowy-log" }
+flowy-user = { path = "../rust-lib/flowy-user" }
 
 [dependencies.sqlx]
 version = "0.5.2"

+ 9 - 2
backend/src/config/config.rs

@@ -1,11 +1,18 @@
-use std::convert::TryFrom;
+use crate::config::DatabaseConfig;
+use std::{convert::TryFrom, sync::Arc};
 
 pub struct Config {
     pub http_port: u16,
+    pub database: Arc<DatabaseConfig>,
 }
 
 impl Config {
-    pub fn new() -> Self { Config { http_port: 3030 } }
+    pub fn new() -> Self {
+        Config {
+            http_port: 3030,
+            database: Arc::new(DatabaseConfig::default()),
+        }
+    }
 
     pub fn server_addr(&self) -> String { format!("0.0.0.0:{}", self.http_port) }
 }

+ 1 - 0
backend/src/config/const_define.rs

@@ -2,3 +2,4 @@ use std::time::Duration;
 
 pub const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(8);
 pub const PING_TIMEOUT: Duration = Duration::from_secs(60);
+pub const MAX_PAYLOAD_SIZE: usize = 262_144; // max payload size is 256k

+ 5 - 0
backend/src/config/database/config.toml

@@ -0,0 +1,5 @@
+host = "localhost"
+port = 5433
+username = "postgres"
+password = "password"
+database_name = "flowy"

+ 33 - 0
backend/src/config/database/database.rs

@@ -0,0 +1,33 @@
+use serde::Deserialize;
+
+#[derive(Deserialize)]
+pub struct DatabaseConfig {
+    username: String,
+    password: String,
+    port: u16,
+    host: String,
+    database_name: String,
+}
+
+impl DatabaseConfig {
+    pub fn connect_url(&self) -> String {
+        format!(
+            "postgres://{}:{}@{}:{}/{}",
+            self.username, self.password, self.host, self.port, self.database_name
+        )
+    }
+
+    pub fn set_env_db_url(&self) {
+        let url = self.connect_url();
+        std::env::set_var("DATABASE_URL", url);
+    }
+}
+
+impl std::default::Default for DatabaseConfig {
+    fn default() -> DatabaseConfig {
+        let toml_str: &str = include_str!("config.toml");
+        let config: DatabaseConfig = toml::from_str(toml_str).unwrap();
+        config.set_env_db_url();
+        config
+    }
+}

+ 3 - 0
backend/src/config/database/mod.rs

@@ -0,0 +1,3 @@
+mod database;
+
+pub use database::*;

+ 2 - 0
backend/src/config/mod.rs

@@ -1,5 +1,7 @@
 mod config;
 mod const_define;
+mod database;
 
 pub use config::*;
 pub use const_define::*;
+pub use database::*;

+ 16 - 7
backend/src/context.rs

@@ -1,19 +1,28 @@
-use crate::{config::Config, ws_service::WSServer};
+use crate::{config::Config, user_service::Auth, ws_service::WSServer};
 use actix::Addr;
+
+use sqlx::PgPool;
 use std::sync::Arc;
 
 pub struct AppContext {
     pub config: Arc<Config>,
-    pub server: Addr<WSServer>,
+    pub ws_server: Addr<WSServer>,
+    pub db_pool: Arc<PgPool>,
+    pub auth: Arc<Auth>,
 }
 
 impl AppContext {
-    pub fn new(server: Addr<WSServer>) -> Self {
+    pub fn new(
+        config: Arc<Config>,
+        ws_server: Addr<WSServer>,
+        db_pool: Arc<PgPool>,
+        auth: Arc<Auth>,
+    ) -> Self {
         AppContext {
-            config: Arc::new(Config::new()),
-            server,
+            config,
+            ws_server,
+            db_pool,
+            auth,
         }
     }
-
-    pub fn ws_server(&self) -> Addr<WSServer> { self.server.clone() }
 }

+ 7 - 0
backend/src/entities/mod.rs

@@ -0,0 +1,7 @@
+mod response;
+mod response_serde;
+mod server_code;
+
+pub use response::*;
+pub use response_serde::*;
+pub use server_code::*;

+ 45 - 0
backend/src/entities/response.rs

@@ -0,0 +1,45 @@
+use crate::{entities::ServerCode, errors::ServerError};
+use actix_web::{body::Body, HttpResponse, ResponseError};
+
+use serde::Serialize;
+
+#[derive(Debug, Serialize)]
+pub struct ServerResponse<T> {
+    pub msg: String,
+    pub data: Option<T>,
+    pub code: ServerCode,
+}
+
+impl<T: Serialize> ServerResponse<T> {
+    pub fn new(data: Option<T>, msg: &str, code: ServerCode) -> Self {
+        ServerResponse {
+            msg: msg.to_owned(),
+            data,
+            code,
+        }
+    }
+
+    pub fn from_data(data: T, msg: &str, code: ServerCode) -> Self {
+        Self::new(Some(data), msg, code)
+    }
+}
+
+impl ServerResponse<String> {
+    pub fn success() -> Self { Self::from_msg("", ServerCode::Success) }
+
+    pub fn from_msg(msg: &str, code: ServerCode) -> Self {
+        Self::new(Some("".to_owned()), msg, code)
+    }
+}
+
+impl<T: Serialize> std::convert::Into<HttpResponse> for ServerResponse<T> {
+    fn into(self) -> HttpResponse {
+        match serde_json::to_string(&self) {
+            Ok(body) => HttpResponse::Ok().body(Body::from(body)),
+            Err(e) => {
+                let msg = format!("Serial error: {:?}", e);
+                ServerError::InternalError(msg).error_response()
+            },
+        }
+    }
+}

+ 128 - 0
backend/src/entities/response_serde.rs

@@ -0,0 +1,128 @@
+use crate::entities::{ServerCode, ServerResponse};
+use serde::{
+    de::{self, MapAccess, Visitor},
+    Deserialize,
+    Deserializer,
+    Serialize,
+};
+use std::{fmt, marker::PhantomData, str::FromStr};
+
+pub trait ServerData<'a>: Serialize + Deserialize<'a> + FromStr<Err = ()> {}
+impl<'de, T: ServerData<'de>> Deserialize<'de> for ServerResponse<T> {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        struct ServerResponseVisitor<T>(PhantomData<fn() -> T>);
+        impl<'de, T> Visitor<'de> for ServerResponseVisitor<T>
+        where
+            T: ServerData<'de>,
+        {
+            type Value = ServerResponse<T>;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+                formatter.write_str("struct Duration")
+            }
+
+            fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
+            where
+                V: MapAccess<'de>,
+            {
+                let mut msg = None;
+                let mut data: Option<T> = None;
+                let mut code: Option<ServerCode> = None;
+                while let Some(key) = map.next_key()? {
+                    match key {
+                        "msg" => {
+                            if msg.is_some() {
+                                return Err(de::Error::duplicate_field("msg"));
+                            }
+                            msg = Some(map.next_value()?);
+                        },
+                        "code" => {
+                            if code.is_some() {
+                                return Err(de::Error::duplicate_field("code"));
+                            }
+                            code = Some(map.next_value()?);
+                        },
+                        "data" => {
+                            if data.is_some() {
+                                return Err(de::Error::duplicate_field("data"));
+                            }
+                            data = match MapAccess::next_value::<DeserializeWith<T>>(&mut map) {
+                                Ok(wrapper) => wrapper.value,
+                                Err(err) => return Err(err),
+                            };
+                        },
+                        _ => panic!(),
+                    }
+                }
+                let msg = msg.ok_or_else(|| de::Error::missing_field("msg"))?;
+                let code = code.ok_or_else(|| de::Error::missing_field("code"))?;
+                Ok(Self::Value::new(data, msg, code))
+            }
+        }
+        const FIELDS: &'static [&'static str] = &["msg", "code", "data"];
+        deserializer.deserialize_struct(
+            "ServerResponse",
+            FIELDS,
+            ServerResponseVisitor(PhantomData),
+        )
+    }
+}
+
+struct DeserializeWith<'de, T: ServerData<'de>> {
+    value: Option<T>,
+    phantom: PhantomData<&'de ()>,
+}
+
+impl<'de, T: ServerData<'de>> Deserialize<'de> for DeserializeWith<'de, T> {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        Ok(DeserializeWith {
+            value: match string_or_data(deserializer) {
+                Ok(val) => val,
+                Err(e) => return Err(e),
+            },
+            phantom: PhantomData,
+        })
+    }
+}
+
+fn string_or_data<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
+where
+    D: Deserializer<'de>,
+    T: ServerData<'de>,
+{
+    struct StringOrData<T>(PhantomData<fn() -> T>);
+    impl<'de, T: ServerData<'de>> Visitor<'de> for StringOrData<T> {
+        type Value = Option<T>;
+
+        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+            formatter.write_str("string or struct impl deserialize")
+        }
+
+        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+        where
+            E: de::Error,
+        {
+            match FromStr::from_str(value) {
+                Ok(val) => Ok(Some(val)),
+                Err(_e) => Ok(None),
+            }
+        }
+
+        fn visit_map<M>(self, map: M) -> Result<Self::Value, M::Error>
+        where
+            M: MapAccess<'de>,
+        {
+            match Deserialize::deserialize(de::value::MapAccessDeserializer::new(map)) {
+                Ok(val) => Ok(Some(val)),
+                Err(e) => Err(e),
+            }
+        }
+    }
+    deserializer.deserialize_any(StringOrData(PhantomData))
+}

+ 12 - 0
backend/src/entities/server_code.rs

@@ -0,0 +1,12 @@
+use serde_repr::*;
+
+#[derive(Serialize_repr, Deserialize_repr, PartialEq, Debug)]
+#[repr(u16)]
+pub enum ServerCode {
+    Success          = 0,
+    InvalidToken     = 1,
+    InternalError    = 2,
+    Unauthorized     = 3,
+    PayloadOverflow  = 4,
+    PayloadSerdeFail = 5,
+}

+ 47 - 2
backend/src/errors.rs

@@ -1,3 +1,48 @@
-pub struct ServerError {}
+use crate::entities::{ServerCode, ServerResponse};
+use actix_web::{error::ResponseError, HttpResponse};
+use protobuf::ProtobufError;
+use std::fmt::Formatter;
 
-// pub enum ErrorCode {}
+#[derive(Debug)]
+pub enum ServerError {
+    InternalError(String),
+    BadRequest(ServerResponse<String>),
+    Unauthorized,
+}
+
+impl std::fmt::Display for ServerError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            ServerError::InternalError(_) => f.write_str("Internal Server Error"),
+            ServerError::BadRequest(request) => {
+                let msg = format!("Bad Request: {:?}", request);
+                f.write_str(&msg)
+            },
+            ServerError::Unauthorized => f.write_str("Unauthorized"),
+        }
+    }
+}
+
+impl ResponseError for ServerError {
+    fn error_response(&self) -> HttpResponse {
+        match self {
+            ServerError::InternalError(msg) => {
+                let msg = format!("Internal Server Error. {}", msg);
+                let resp = ServerResponse::from_msg(&msg, ServerCode::InternalError);
+                HttpResponse::InternalServerError().json(resp)
+            },
+            ServerError::BadRequest(ref resp) => HttpResponse::BadRequest().json(resp),
+            ServerError::Unauthorized => {
+                let resp = ServerResponse::from_msg("Unauthorized", ServerCode::Unauthorized);
+                HttpResponse::Unauthorized().json(resp)
+            },
+        }
+    }
+}
+
+impl std::convert::From<ProtobufError> for ServerError {
+    fn from(err: ProtobufError) -> Self {
+        let msg = format!("{:?}", err);
+        ServerError::InternalError(msg)
+    }
+}

+ 2 - 0
backend/src/lib.rs

@@ -1,6 +1,8 @@
 mod config;
 mod context;
+mod entities;
 mod errors;
 mod routers;
 pub mod startup;
+pub mod user_service;
 pub mod ws_service;

+ 34 - 0
backend/src/routers/helper.rs

@@ -0,0 +1,34 @@
+use crate::{
+    config::MAX_PAYLOAD_SIZE,
+    entities::{ServerCode, ServerResponse},
+    errors::ServerError,
+};
+use actix_web::web;
+use futures::StreamExt;
+use protobuf::{Message, ProtobufResult};
+
+pub async fn parse_from_payload<T: Message>(payload: web::Payload) -> Result<T, ServerError> {
+    let bytes = poll_payload(payload).await?;
+    parse_from_bytes(&bytes)
+}
+
+pub fn parse_from_bytes<T: Message>(bytes: &[u8]) -> Result<T, ServerError> {
+    let result: ProtobufResult<T> = Message::parse_from_bytes(&bytes);
+    match result {
+        Ok(data) => Ok(data),
+        Err(e) => Err(e.into()),
+    }
+}
+
+pub async fn poll_payload(mut payload: web::Payload) -> Result<web::BytesMut, ServerError> {
+    let mut body = web::BytesMut::new();
+    while let Some(chunk) = payload.next().await {
+        let chunk = chunk.map_err(|e| ServerError::InternalError(format!("{:?}", e)))?;
+        if (body.len() + chunk.len()) > MAX_PAYLOAD_SIZE {
+            let resp = ServerResponse::from_msg("Payload overflow", ServerCode::PayloadOverflow);
+            return Err(ServerError::BadRequest(resp));
+        }
+        body.extend_from_slice(&chunk);
+    }
+    Ok(body)
+}

+ 3 - 0
backend/src/routers/mod.rs

@@ -1,3 +1,6 @@
+mod helper;
+mod user;
 pub(crate) mod ws;
 
+pub use user::*;
 pub use ws::*;

+ 23 - 0
backend/src/routers/user.rs

@@ -0,0 +1,23 @@
+use crate::user_service::Auth;
+use actix_web::{
+    web::{Data, Payload},
+    Error,
+    HttpRequest,
+    HttpResponse,
+};
+use flowy_user::protobuf::SignUpRequest;
+
+use crate::{entities::ServerResponse, routers::helper::parse_from_payload};
+
+use std::sync::Arc;
+
+pub async fn user_register(
+    request: HttpRequest,
+    payload: Payload,
+    auth: Data<Arc<Auth>>,
+) -> Result<HttpResponse, Error> {
+    let request: SignUpRequest = parse_from_payload(payload).await?;
+    // ProtobufError
+    let resp = ServerResponse::success();
+    Ok(resp.into())
+}

+ 1 - 1
backend/src/routers/ws.rs

@@ -1,6 +1,6 @@
 use crate::ws_service::{entities::SessionId, WSClient, WSServer};
 use actix::Addr;
-use actix_http::{body::Body, Response};
+
 use actix_web::{
     get,
     web::{Data, Path, Payload},

+ 22 - 3
backend/src/startup.rs

@@ -1,6 +1,13 @@
-use crate::{context::AppContext, routers::*, ws_service::WSServer};
+use crate::{
+    config::Config,
+    context::AppContext,
+    routers::*,
+    user_service::Auth,
+    ws_service::WSServer,
+};
 use actix::Actor;
 use actix_web::{dev::Server, middleware, web, App, HttpServer, Scope};
+use sqlx::PgPool;
 use std::{net::TcpListener, sync::Arc};
 
 pub fn run(app_ctx: Arc<AppContext>, listener: TcpListener) -> Result<Server, std::io::Error> {
@@ -9,7 +16,9 @@ pub fn run(app_ctx: Arc<AppContext>, listener: TcpListener) -> Result<Server, st
             .wrap(middleware::Logger::default())
             .data(web::JsonConfig::default().limit(4096))
             .service(ws_scope())
-            .data(app_ctx.ws_server())
+            .data(app_ctx.ws_server.clone())
+            .data(app_ctx.db_pool.clone())
+            .data(app_ctx.auth.clone())
     })
     .listen(listener)?
     .run();
@@ -20,7 +29,17 @@ fn ws_scope() -> Scope { web::scope("/ws").service(ws::start_connection) }
 
 pub async fn init_app_context() -> Arc<AppContext> {
     let _ = flowy_log::Builder::new("flowy").env_filter("Debug").build();
+    let config = Arc::new(Config::new());
+
+    // TODO: what happened when PgPool connect fail?
+    let db_pool = Arc::new(
+        PgPool::connect(&config.database.connect_url())
+            .await
+            .expect("Failed to connect to Postgres."),
+    );
     let ws_server = WSServer::new().start();
-    let ctx = AppContext::new(ws_server);
+    let auth = Arc::new(Auth::new(db_pool.clone()));
+
+    let ctx = AppContext::new(config, ws_server, db_pool, auth);
     Arc::new(ctx)
 }

+ 14 - 0
backend/src/user_service/auth.rs

@@ -0,0 +1,14 @@
+use crate::errors::ServerError;
+use flowy_user::protobuf::SignUpRequest;
+use sqlx::PgPool;
+use std::sync::Arc;
+
+pub struct Auth {
+    db_pool: Arc<PgPool>,
+}
+
+impl Auth {
+    pub fn new(db_pool: Arc<PgPool>) -> Self { Self { db_pool } }
+
+    pub fn handle_sign_up(&self, request: SignUpRequest) -> Result<(), ServerError> { Ok(()) }
+}

+ 3 - 0
backend/src/user_service/mod.rs

@@ -0,0 +1,3 @@
+mod auth;
+
+pub use auth::*;

+ 0 - 1
backend/src/ws_service/ws_client.rs

@@ -16,7 +16,6 @@ use actix::{
     AsyncContext,
     ContextFutureSpawner,
     Handler,
-    Recipient,
     Running,
     StreamHandler,
     WrapFuture,

+ 0 - 1
backend/src/ws_service/ws_server.rs

@@ -3,7 +3,6 @@ use crate::{
     ws_service::{
         entities::{Connect, Disconnect, Session, SessionId},
         ClientMessage,
-        WSClient,
     },
 };
 use actix::{Actor, Context, Handler};

+ 1 - 1
rust-lib/flowy-log/src/lib.rs

@@ -24,7 +24,7 @@ impl Builder {
         self
     }
 
-    pub fn local(mut self, directory: impl AsRef<Path>) -> Self {
+    pub fn local(self, directory: impl AsRef<Path>) -> Self {
         let directory = directory.as_ref().to_str().unwrap().to_owned();
         let local_file_name = format!("{}.log", &self.name);
         let file_appender = tracing_appender::rolling::daily(directory, local_file_name);

+ 1 - 1
rust-lib/flowy-user/src/lib.rs

@@ -3,7 +3,7 @@ pub mod errors;
 pub mod event;
 mod handlers;
 pub mod module;
-mod protobuf;
+pub mod protobuf;
 mod services;
 pub mod sql_tables;