ソースを参照

add auth middleware

appflowy 3 年 前
コミット
cfc10fb38e

+ 2 - 0
backend/Cargo.toml

@@ -15,7 +15,9 @@ actix-web = "4.0.0-beta.8"
 actix-http = "3.0.0-beta.8"
 actix-rt = "2"
 actix-web-actors = { version = "4.0.0-beta.6" }
+actix-service = "2.0.0-beta.5"
 actix-identity = "0.4.0-beta.2"
+#actix-cors = "0.5.4"
 
 futures = "0.3.15"
 bytes = "1"

+ 1 - 0
backend/src/application.rs

@@ -53,6 +53,7 @@ pub fn run(listener: TcpListener, app_ctx: AppContext) -> Result<Server, std::io
         App::new()
             .wrap(middleware::Logger::default())
             .wrap(identify_service(&domain, &secret))
+            .wrap(crate::middleware::AuthenticationService)
             .app_data(web::JsonConfig::default().limit(4096))
             .service(ws_scope())
             .service(user_scope())

+ 2 - 0
backend/src/config/const_define.rs

@@ -3,3 +3,5 @@ use std::time::Duration;
 pub const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(8);
 pub const PING_TIMEOUT: Duration = Duration::from_secs(60);
 pub const MAX_PAYLOAD_SIZE: usize = 262_144; // max payload size is 256k
+
+pub const IGNORE_ROUTES: [&str; 2] = ["/api/register", "/api/auth/"];

+ 1 - 0
backend/src/lib.rs

@@ -2,6 +2,7 @@ pub mod application;
 pub mod config;
 mod context;
 mod entities;
+mod middleware;
 mod routers;
 mod sqlx_ext;
 pub mod user_service;

+ 95 - 0
backend/src/middleware/auth_middleware.rs

@@ -0,0 +1,95 @@
+use crate::user_service::{LoggedUser, AUTHORIZED_USERS};
+use actix_service::{Service, Transform};
+use actix_web::{
+    dev::{ServiceRequest, ServiceResponse},
+    http::{HeaderName, HeaderValue, Method},
+    web::Data,
+    Error,
+    HttpResponse,
+    ResponseError,
+};
+
+use crate::config::IGNORE_ROUTES;
+use actix_web::{body::AnyBody, dev::MessageBody};
+use flowy_net::{config::HEADER_TOKEN, errors::ServerError, response::FlowyResponse};
+use futures::{
+    future::{ok, LocalBoxFuture, Ready},
+    Future,
+};
+use std::{
+    convert::TryInto,
+    error::Error as StdError,
+    pin::Pin,
+    task::{Context, Poll},
+};
+
+pub struct AuthenticationService;
+
+impl<S, B> Transform<S, ServiceRequest> for AuthenticationService
+where
+    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
+    S::Future: 'static,
+    B: MessageBody + 'static,
+    B::Error: StdError,
+{
+    type Response = ServiceResponse;
+    type Error = Error;
+    type Transform = AuthenticationMiddleware<S>;
+    type InitError = ();
+    type Future = Ready<Result<Self::Transform, Self::InitError>>;
+
+    fn new_transform(&self, service: S) -> Self::Future { ok(AuthenticationMiddleware { service }) }
+}
+pub struct AuthenticationMiddleware<S> {
+    service: S,
+}
+
+impl<S, B> Service<ServiceRequest> for AuthenticationMiddleware<S>
+where
+    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
+    S::Future: 'static,
+    B: MessageBody + 'static,
+    B::Error: StdError,
+{
+    type Response = ServiceResponse;
+    type Error = Error;
+    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
+
+    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.service.poll_ready(cx)
+    }
+
+    fn call(&self, mut req: ServiceRequest) -> Self::Future {
+        let mut authenticate_pass: bool = false;
+        for ignore_route in IGNORE_ROUTES.iter() {
+            if req.path().starts_with(ignore_route) {
+                authenticate_pass = true;
+                break;
+            }
+        }
+
+        if !authenticate_pass {
+            if let Some(header) = req.headers().get(HEADER_TOKEN) {
+                let logger_user: LoggedUser = header.try_into().unwrap();
+                if AUTHORIZED_USERS.is_authorized(&logger_user) {
+                    authenticate_pass = true;
+                }
+            }
+        }
+
+        if authenticate_pass {
+            let fut = self.service.call(req);
+            return Box::pin(async move {
+                let res = fut.await?;
+                Ok(res.map_body(|_, body| AnyBody::from_message(body)))
+            });
+        } else {
+            Box::pin(async move { Ok(req.into_response(unauthorized_response())) })
+        }
+    }
+}
+
+fn unauthorized_response() -> HttpResponse {
+    let error = ServerError::unauthorized();
+    error.error_response()
+}

+ 3 - 0
backend/src/middleware/mod.rs

@@ -0,0 +1,3 @@
+mod auth_middleware;
+
+pub use auth_middleware::*;

+ 10 - 0
rust-lib/flowy-net/src/response/response_http.rs

@@ -2,6 +2,7 @@ use crate::response::*;
 use actix_web::{error::ResponseError, HttpResponse};
 
 use crate::errors::ServerError;
+use actix_web::body::AnyBody;
 
 impl ResponseError for ServerError {
     fn error_response(&self) -> HttpResponse {
@@ -12,3 +13,12 @@ impl ResponseError for ServerError {
 impl std::convert::Into<HttpResponse> for FlowyResponse {
     fn into(self) -> HttpResponse { HttpResponse::Ok().json(self) }
 }
+
+impl std::convert::Into<AnyBody> for FlowyResponse {
+    fn into(self) -> AnyBody {
+        match serde_json::to_string(&self) {
+            Ok(body) => AnyBody::from(body),
+            Err(err) => AnyBody::Empty,
+        }
+    }
+}