appflowy il y a 3 ans
Parent
commit
260060ac5c

+ 2 - 0
app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/errors.pbenum.dart

@@ -13,11 +13,13 @@ class ErrorCode extends $pb.ProtobufEnum {
   static const ErrorCode InternalError = ErrorCode._(0, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'InternalError');
   static const ErrorCode DuplicateSource = ErrorCode._(1, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'DuplicateSource');
   static const ErrorCode UnsupportedMessage = ErrorCode._(2, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'UnsupportedMessage');
+  static const ErrorCode Unauthorized = ErrorCode._(3, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'Unauthorized');
 
   static const $core.List<ErrorCode> values = <ErrorCode> [
     InternalError,
     DuplicateSource,
     UnsupportedMessage,
+    Unauthorized,
   ];
 
   static final $core.Map<$core.int, ErrorCode> _byValue = $pb.ProtobufEnum.initByValue(values);

+ 2 - 1
app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/errors.pbjson.dart

@@ -15,11 +15,12 @@ const ErrorCode$json = const {
     const {'1': 'InternalError', '2': 0},
     const {'1': 'DuplicateSource', '2': 1},
     const {'1': 'UnsupportedMessage', '2': 2},
+    const {'1': 'Unauthorized', '2': 3},
   ],
 };
 
 /// Descriptor for `ErrorCode`. Decode as a `google.protobuf.EnumDescriptorProto`.
-final $typed_data.Uint8List errorCodeDescriptor = $convert.base64Decode('CglFcnJvckNvZGUSEQoNSW50ZXJuYWxFcnJvchAAEhMKD0R1cGxpY2F0ZVNvdXJjZRABEhYKElVuc3VwcG9ydGVkTWVzc2FnZRAC');
+final $typed_data.Uint8List errorCodeDescriptor = $convert.base64Decode('CglFcnJvckNvZGUSEQoNSW50ZXJuYWxFcnJvchAAEhMKD0R1cGxpY2F0ZVNvdXJjZRABEhYKElVuc3VwcG9ydGVkTWVzc2FnZRACEhAKDFVuYXV0aG9yaXplZBAD');
 @$core.Deprecated('Use wsErrorDescriptor instead')
 const WsError$json = const {
   '1': 'WsError',

+ 1 - 0
backend/Cargo.toml

@@ -83,6 +83,7 @@ name = "backend"
 path = "src/main.rs"
 
 [dev-dependencies]
+parking_lot = "0.11"
 once_cell = "1.7.2"
 linkify = "0.5.0"
 flowy-user = { path = "../rust-lib/flowy-user" }

+ 78 - 5
backend/tests/api/ws.rs

@@ -1,10 +1,83 @@
 use crate::helper::TestServer;
-use flowy_ws::WsController;
+use flowy_ws::{WsController, WsSender, WsState};
+use parking_lot::RwLock;
+use std::sync::Arc;
+
+pub struct WsTest {
+    server: TestServer,
+    ws_controller: Arc<RwLock<WsController>>,
+}
+
+#[derive(Clone)]
+pub enum WsScript {
+    SendText(&'static str),
+    SendBinary(Vec<u8>),
+    Disconnect(&'static str),
+}
+
+impl WsTest {
+    pub async fn new(scripts: Vec<WsScript>) -> Self {
+        let server = TestServer::new().await;
+        let ws_controller = Arc::new(RwLock::new(WsController::new()));
+        ws_controller
+            .write()
+            .state_callback(move |state| match state {
+                WsState::Connected(sender) => {
+                    WsScriptRunner {
+                        scripts: scripts.clone(),
+                        sender: sender.clone(),
+                        source: "editor".to_owned(),
+                    }
+                    .run();
+                },
+                _ => {},
+            })
+            .await;
+
+        Self {
+            server,
+            ws_controller,
+        }
+    }
+
+    pub async fn run_scripts(&mut self) {
+        let addr = self.server.ws_addr();
+        self.ws_controller.write().connect(addr).unwrap().await;
+    }
+}
+
+struct WsScriptRunner {
+    scripts: Vec<WsScript>,
+    sender: Arc<WsSender>,
+    source: String,
+}
+
+impl WsScriptRunner {
+    fn run(self) {
+        for script in self.scripts {
+            match script {
+                WsScript::SendText(text) => {
+                    self.sender.send_text(&self.source, text).unwrap();
+                },
+                WsScript::SendBinary(bytes) => {
+                    self.sender.send_binary(&self.source, bytes).unwrap();
+                },
+                WsScript::Disconnect(reason) => {
+                    self.sender.send_disconnect(reason).unwrap();
+                },
+            }
+        }
+    }
+}
 
 #[actix_rt::test]
 async fn ws_connect() {
-    let server = TestServer::new().await;
-    let mut controller = WsController::new();
-    let addr = server.ws_addr();
-    let _ = controller.connect(addr).unwrap().await;
+    let mut ws = WsTest::new(vec![
+        WsScript::SendText("abc"),
+        WsScript::SendText("abc"),
+        WsScript::SendText("abc"),
+        WsScript::Disconnect("abc"),
+    ])
+    .await;
+    ws.run_scripts().await
 }

+ 10 - 9
rust-lib/flowy-user/src/services/user/user_session.rs

@@ -178,15 +178,16 @@ impl UserSession {
         Ok(())
     }
 
-    pub fn send_ws_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), UserError> {
-        match self.ws_controller.try_read_for(Duration::from_millis(300)) {
-            None => Err(UserError::internal().context("Send ws message timeout")),
-            Some(guard) => {
-                let _ = guard.send_msg(msg)?;
-                Ok(())
-            },
-        }
-    }
+    // pub fn send_ws_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(),
+    // UserError> {     match self.ws_controller.try_read_for(Duration::
+    // from_millis(300)) {         None =>
+    // Err(UserError::internal().context("Send ws message timeout")),
+    //         Some(guard) => {
+    //             let _ = guard.send_msg(msg)?;
+    //             Ok(())
+    //         },
+    //     }
+    // }
 }
 
 impl UserSession {

+ 2 - 17
rust-lib/flowy-ws/src/connect.rs

@@ -37,7 +37,7 @@ impl WsConnection {
 }
 
 impl Future for WsConnection {
-    type Output = Result<WsStream, ServerError>;
+    type Output = Result<WsStream, WsError>;
     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
         // [[pin]]
         // poll async function.  The following methods not work.
@@ -65,7 +65,7 @@ impl Future for WsConnection {
                 },
                 Err(error) => {
                     log::debug!("🐴 ws connect failed: {:?}", error);
-                    Poll::Ready(Err(error_to_flowy_response(error)))
+                    Poll::Ready(Err(error.into()))
                 },
             };
         }
@@ -135,21 +135,6 @@ fn post_message(tx: MsgSender, message: Result<Message, Error>) {
     }
 }
 
-fn error_to_flowy_response(error: tokio_tungstenite::tungstenite::Error) -> ServerError {
-    let error = match error {
-        Error::Http(response) => {
-            if response.status() == StatusCode::UNAUTHORIZED {
-                ServerError::unauthorized()
-            } else {
-                ServerError::internal().context(response)
-            }
-        },
-        _ => ServerError::internal().context(error),
-    };
-
-    error
-}
-
 pub struct Retry<F> {
     f: F,
     retry_time: usize,

+ 17 - 2
rust-lib/flowy-ws/src/errors.rs

@@ -2,7 +2,7 @@ use flowy_derive::{ProtoBuf, ProtoBuf_Enum};
 use futures_channel::mpsc::TrySendError;
 use std::fmt::Debug;
 use strum_macros::Display;
-use tokio_tungstenite::tungstenite::Message;
+use tokio_tungstenite::tungstenite::{http::StatusCode, Message};
 use url::ParseError;
 
 #[derive(Debug, Default, Clone, ProtoBuf)]
@@ -38,6 +38,7 @@ impl WsError {
     static_user_error!(internal, ErrorCode::InternalError);
     static_user_error!(duplicate_source, ErrorCode::DuplicateSource);
     static_user_error!(unsupported_message, ErrorCode::UnsupportedMessage);
+    static_user_error!(unauthorized, ErrorCode::Unauthorized);
 }
 
 #[derive(Debug, Clone, ProtoBuf_Enum, Display, PartialEq, Eq)]
@@ -45,6 +46,7 @@ pub enum ErrorCode {
     InternalError      = 0,
     DuplicateSource    = 1,
     UnsupportedMessage = 2,
+    Unauthorized       = 3,
 }
 
 impl std::default::Default for ErrorCode {
@@ -64,5 +66,18 @@ impl std::convert::From<futures_channel::mpsc::TrySendError<Message>> for WsErro
 }
 
 impl std::convert::From<tokio_tungstenite::tungstenite::Error> for WsError {
-    fn from(error: tokio_tungstenite::tungstenite::Error) -> Self { WsError::internal().context(error) }
+    fn from(error: tokio_tungstenite::tungstenite::Error) -> Self {
+        let error = match error {
+            tokio_tungstenite::tungstenite::Error::Http(response) => {
+                if response.status() == StatusCode::UNAUTHORIZED {
+                    WsError::unauthorized()
+                } else {
+                    WsError::internal().context(response)
+                }
+            },
+            _ => WsError::internal().context(error),
+        };
+
+        error
+    }
 }

+ 22 - 17
rust-lib/flowy-ws/src/protobuf/model/errors.rs

@@ -218,6 +218,7 @@ pub enum ErrorCode {
     InternalError = 0,
     DuplicateSource = 1,
     UnsupportedMessage = 2,
+    Unauthorized = 3,
 }
 
 impl ::protobuf::ProtobufEnum for ErrorCode {
@@ -230,6 +231,7 @@ impl ::protobuf::ProtobufEnum for ErrorCode {
             0 => ::std::option::Option::Some(ErrorCode::InternalError),
             1 => ::std::option::Option::Some(ErrorCode::DuplicateSource),
             2 => ::std::option::Option::Some(ErrorCode::UnsupportedMessage),
+            3 => ::std::option::Option::Some(ErrorCode::Unauthorized),
             _ => ::std::option::Option::None
         }
     }
@@ -239,6 +241,7 @@ impl ::protobuf::ProtobufEnum for ErrorCode {
             ErrorCode::InternalError,
             ErrorCode::DuplicateSource,
             ErrorCode::UnsupportedMessage,
+            ErrorCode::Unauthorized,
         ];
         values
     }
@@ -268,24 +271,26 @@ impl ::protobuf::reflect::ProtobufValue for ErrorCode {
 
 static file_descriptor_proto_data: &'static [u8] = b"\
     \n\x0cerrors.proto\";\n\x07WsError\x12\x1e\n\x04code\x18\x01\x20\x01(\
-    \x0e2\n.ErrorCodeR\x04code\x12\x10\n\x03msg\x18\x02\x20\x01(\tR\x03msg*K\
+    \x0e2\n.ErrorCodeR\x04code\x12\x10\n\x03msg\x18\x02\x20\x01(\tR\x03msg*]\
     \n\tErrorCode\x12\x11\n\rInternalError\x10\0\x12\x13\n\x0fDuplicateSourc\
-    e\x10\x01\x12\x16\n\x12UnsupportedMessage\x10\x02J\xab\x02\n\x06\x12\x04\
-    \0\0\n\x01\n\x08\n\x01\x0c\x12\x03\0\0\x12\n\n\n\x02\x04\0\x12\x04\x02\0\
-    \x05\x01\n\n\n\x03\x04\0\x01\x12\x03\x02\x08\x0f\n\x0b\n\x04\x04\0\x02\0\
-    \x12\x03\x03\x04\x17\n\x0c\n\x05\x04\0\x02\0\x06\x12\x03\x03\x04\r\n\x0c\
-    \n\x05\x04\0\x02\0\x01\x12\x03\x03\x0e\x12\n\x0c\n\x05\x04\0\x02\0\x03\
-    \x12\x03\x03\x15\x16\n\x0b\n\x04\x04\0\x02\x01\x12\x03\x04\x04\x13\n\x0c\
-    \n\x05\x04\0\x02\x01\x05\x12\x03\x04\x04\n\n\x0c\n\x05\x04\0\x02\x01\x01\
-    \x12\x03\x04\x0b\x0e\n\x0c\n\x05\x04\0\x02\x01\x03\x12\x03\x04\x11\x12\n\
-    \n\n\x02\x05\0\x12\x04\x06\0\n\x01\n\n\n\x03\x05\0\x01\x12\x03\x06\x05\
-    \x0e\n\x0b\n\x04\x05\0\x02\0\x12\x03\x07\x04\x16\n\x0c\n\x05\x05\0\x02\0\
-    \x01\x12\x03\x07\x04\x11\n\x0c\n\x05\x05\0\x02\0\x02\x12\x03\x07\x14\x15\
-    \n\x0b\n\x04\x05\0\x02\x01\x12\x03\x08\x04\x18\n\x0c\n\x05\x05\0\x02\x01\
-    \x01\x12\x03\x08\x04\x13\n\x0c\n\x05\x05\0\x02\x01\x02\x12\x03\x08\x16\
-    \x17\n\x0b\n\x04\x05\0\x02\x02\x12\x03\t\x04\x1b\n\x0c\n\x05\x05\0\x02\
-    \x02\x01\x12\x03\t\x04\x16\n\x0c\n\x05\x05\0\x02\x02\x02\x12\x03\t\x19\
-    \x1ab\x06proto3\
+    e\x10\x01\x12\x16\n\x12UnsupportedMessage\x10\x02\x12\x10\n\x0cUnauthori\
+    zed\x10\x03J\xd4\x02\n\x06\x12\x04\0\0\x0b\x01\n\x08\n\x01\x0c\x12\x03\0\
+    \0\x12\n\n\n\x02\x04\0\x12\x04\x02\0\x05\x01\n\n\n\x03\x04\0\x01\x12\x03\
+    \x02\x08\x0f\n\x0b\n\x04\x04\0\x02\0\x12\x03\x03\x04\x17\n\x0c\n\x05\x04\
+    \0\x02\0\x06\x12\x03\x03\x04\r\n\x0c\n\x05\x04\0\x02\0\x01\x12\x03\x03\
+    \x0e\x12\n\x0c\n\x05\x04\0\x02\0\x03\x12\x03\x03\x15\x16\n\x0b\n\x04\x04\
+    \0\x02\x01\x12\x03\x04\x04\x13\n\x0c\n\x05\x04\0\x02\x01\x05\x12\x03\x04\
+    \x04\n\n\x0c\n\x05\x04\0\x02\x01\x01\x12\x03\x04\x0b\x0e\n\x0c\n\x05\x04\
+    \0\x02\x01\x03\x12\x03\x04\x11\x12\n\n\n\x02\x05\0\x12\x04\x06\0\x0b\x01\
+    \n\n\n\x03\x05\0\x01\x12\x03\x06\x05\x0e\n\x0b\n\x04\x05\0\x02\0\x12\x03\
+    \x07\x04\x16\n\x0c\n\x05\x05\0\x02\0\x01\x12\x03\x07\x04\x11\n\x0c\n\x05\
+    \x05\0\x02\0\x02\x12\x03\x07\x14\x15\n\x0b\n\x04\x05\0\x02\x01\x12\x03\
+    \x08\x04\x18\n\x0c\n\x05\x05\0\x02\x01\x01\x12\x03\x08\x04\x13\n\x0c\n\
+    \x05\x05\0\x02\x01\x02\x12\x03\x08\x16\x17\n\x0b\n\x04\x05\0\x02\x02\x12\
+    \x03\t\x04\x1b\n\x0c\n\x05\x05\0\x02\x02\x01\x12\x03\t\x04\x16\n\x0c\n\
+    \x05\x05\0\x02\x02\x02\x12\x03\t\x19\x1a\n\x0b\n\x04\x05\0\x02\x03\x12\
+    \x03\n\x04\x15\n\x0c\n\x05\x05\0\x02\x03\x01\x12\x03\n\x04\x10\n\x0c\n\
+    \x05\x05\0\x02\x03\x02\x12\x03\n\x13\x14b\x06proto3\
 ";
 
 static file_descriptor_proto_lazy: ::protobuf::rt::LazyV2<::protobuf::descriptor::FileDescriptorProto> = ::protobuf::rt::LazyV2::INIT;

+ 1 - 0
rust-lib/flowy-ws/src/protobuf/proto/errors.proto

@@ -8,4 +8,5 @@ enum ErrorCode {
     InternalError = 0;
     DuplicateSource = 1;
     UnsupportedMessage = 2;
+    Unauthorized = 3;
 }

+ 100 - 20
rust-lib/flowy-ws/src/ws.rs

@@ -4,6 +4,7 @@ use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
 use futures_core::{ready, Stream};
 
 use crate::connect::Retry;
+use bytes::Buf;
 use futures_core::future::BoxFuture;
 use pin_project::pin_project;
 use std::{
@@ -14,8 +15,15 @@ use std::{
     sync::Arc,
     task::{Context, Poll},
 };
-use tokio::task::JoinHandle;
-use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
+use tokio::{sync::RwLock, task::JoinHandle};
+use tokio_tungstenite::{
+    tungstenite::{
+        protocol::{frame::coding::CloseCode, CloseFrame},
+        Message,
+    },
+    MaybeTlsStream,
+    WebSocketStream,
+};
 
 pub type MsgReceiver = UnboundedReceiver<Message>;
 pub type MsgSender = UnboundedSender<Message>;
@@ -24,22 +32,58 @@ pub trait WsMessageHandler: Sync + Send + 'static {
     fn receive_message(&self, msg: WsMessage);
 }
 
+type NotifyCallback = Arc<dyn Fn(&WsState) + Send + Sync + 'static>;
+struct WsStateNotify {
+    #[allow(dead_code)]
+    state: WsState,
+    callback: Option<NotifyCallback>,
+}
+
+impl WsStateNotify {
+    fn update_state(&mut self, state: WsState) {
+        if let Some(f) = &self.callback {
+            f(&state);
+        }
+        self.state = state;
+    }
+}
+
+pub enum WsState {
+    Init,
+    Connected(Arc<WsSender>),
+    Disconnected(WsError),
+}
+
 pub struct WsController {
-    sender: Option<Arc<WsSender>>,
     handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
+    state_notify: Arc<RwLock<WsStateNotify>>,
     addr: Option<String>,
+    sender: Option<Arc<WsSender>>,
 }
 
 impl WsController {
     pub fn new() -> Self {
+        let state_notify = Arc::new(RwLock::new(WsStateNotify {
+            state: WsState::Init,
+            callback: None,
+        }));
+
         let controller = Self {
-            sender: None,
             handlers: HashMap::new(),
+            state_notify,
             addr: None,
+            sender: None,
         };
         controller
     }
 
+    pub async fn state_callback<SC>(&self, callback: SC)
+    where
+        SC: Fn(&WsState) + Send + Sync + 'static,
+    {
+        (self.state_notify.write().await).callback = Some(Arc::new(callback));
+    }
+
     pub fn add_handler(&mut self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
         let source = handler.source();
         if self.handlers.contains_key(&source) {
@@ -61,9 +105,12 @@ impl WsController {
     fn _connect(&mut self, addr: String, retry: Option<BoxFuture<'static, ()>>) -> Result<JoinHandle<()>, ServerError> {
         log::debug!("🐴 ws connect: {}", &addr);
         let (connection, handlers) = self.make_connect(addr.clone());
+        let state_notify = self.state_notify.clone();
+        let sender = self.sender.clone().expect("Sender should be not empty after calling make_connect");
         Ok(tokio::spawn(async move {
             match connection.await {
                 Ok(stream) => {
+                    state_notify.write().await.update_state(WsState::Connected(sender));
                     tokio::select! {
                         result = stream => {
                             match result {
@@ -71,17 +118,19 @@ impl WsController {
                                 Err(e) => {
                                     // TODO: retry?
                                     log::error!("ws stream error {:?}", e);
+                                    state_notify.write().await.update_state(WsState::Disconnected(e));
                                 }
                             }
                         },
                         result = handlers => log::debug!("handlers completed {:?}", result),
                     };
                 },
-                Err(e) => match retry {
-                    None => log::error!("ws connect {} failed {:?}", addr, e),
-                    Some(retry) => {
+                Err(e) => {
+                    log::error!("ws connect {} failed {:?}", addr, e);
+                    state_notify.write().await.update_state(WsState::Disconnected(e));
+                    if let Some(retry) = retry {
                         tokio::spawn(retry);
-                    },
+                    }
                 },
             }
         }))
@@ -101,17 +150,10 @@ impl WsController {
         let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
         let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
         let handlers = self.handlers.clone();
-        self.sender = Some(Arc::new(WsSender::new(ws_tx)));
+        self.sender = Some(Arc::new(WsSender { ws_tx }));
         self.addr = Some(addr.clone());
         (WsConnection::new(msg_tx, ws_rx, addr), WsHandlers::new(handlers, msg_rx))
     }
-
-    pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
-        match self.sender.as_ref() {
-            None => Err(WsError::internal().context("Should call make_connect first")),
-            Some(sender) => sender.send(msg.into()),
-        }
-    }
 }
 
 #[pin_project]
@@ -146,17 +188,55 @@ impl Future for WsHandlers {
     }
 }
 
-struct WsSender {
+// impl WsSender for WsController {
+//     fn send_msg(&self, msg: WsMessage) -> Result<(), WsError> {
+//         match self.ws_tx.as_ref() {
+//             None => Err(WsError::internal().context("Should call make_connect
+// first")),             Some(sender) => {
+//                 let _ = sender.unbounded_send(msg.into()).map_err(|e|
+// WsError::internal().context(e))?;                 Ok(())
+//             },
+//         }
+//     }
+// }
+
+#[derive(Debug, Clone)]
+pub struct WsSender {
     ws_tx: MsgSender,
 }
 
 impl WsSender {
-    pub fn new(ws_tx: MsgSender) -> Self { Self { ws_tx } }
-
-    pub fn send(&self, msg: WsMessage) -> Result<(), WsError> {
+    pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
+        let msg = msg.into();
         let _ = self.ws_tx.unbounded_send(msg.into()).map_err(|e| WsError::internal().context(e))?;
         Ok(())
     }
+
+    pub fn send_text(&self, source: &str, text: &str) -> Result<(), WsError> {
+        let msg = WsMessage {
+            source: source.to_string(),
+            data: text.as_bytes().to_vec(),
+        };
+        self.send_msg(msg)
+    }
+
+    pub fn send_binary(&self, source: &str, bytes: Vec<u8>) -> Result<(), WsError> {
+        let msg = WsMessage {
+            source: source.to_string(),
+            data: bytes,
+        };
+        self.send_msg(msg)
+    }
+
+    pub fn send_disconnect(&self, reason: &str) -> Result<(), WsError> {
+        let frame = CloseFrame {
+            code: CloseCode::Normal,
+            reason: reason.to_owned().into(),
+        };
+        let msg = Message::Close(Some(frame));
+        let _ = self.ws_tx.unbounded_send(msg).map_err(|e| WsError::internal().context(e))?;
+        Ok(())
+    }
 }
 
 #[cfg(test)]