Browse Source

fix async future bug

appflowy 3 years ago
parent
commit
2b32f2111f
27 changed files with 790 additions and 208 deletions
  1. 4 0
      app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/errors.pbenum.dart
  2. 3 1
      app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/errors.pbjson.dart
  3. 72 0
      app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pb.dart
  4. 7 0
      app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbenum.dart
  5. 21 0
      app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbjson.dart
  6. 9 0
      app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbserver.dart
  7. 1 0
      app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/protobuf.dart
  8. 1 1
      backend/src/middleware/auth_middleware.rs
  9. 27 21
      backend/tests/api/helper.rs
  10. 1 1
      backend/tests/api/ws.rs
  11. 1 1
      rust-lib/flowy-ast/src/ty_ext.rs
  12. 1 0
      rust-lib/flowy-derive/src/derive_cache/derive_cache.rs
  13. 0 1
      rust-lib/flowy-dispatch/tests/api/module.rs
  14. 11 11
      rust-lib/flowy-net/src/config.rs
  15. 4 0
      rust-lib/flowy-user/src/errors.rs
  16. 21 19
      rust-lib/flowy-user/src/services/user/user_session.rs
  17. 1 1
      rust-lib/flowy-ws/Flowy.toml
  18. 207 0
      rust-lib/flowy-ws/src/connect.rs
  19. 9 1
      rust-lib/flowy-ws/src/errors.rs
  20. 3 0
      rust-lib/flowy-ws/src/lib.rs
  21. 38 0
      rust-lib/flowy-ws/src/msg.rs
  22. 24 13
      rust-lib/flowy-ws/src/protobuf/model/errors.rs
  23. 3 0
      rust-lib/flowy-ws/src/protobuf/model/mod.rs
  24. 250 0
      rust-lib/flowy-ws/src/protobuf/model/msg.rs
  25. 2 0
      rust-lib/flowy-ws/src/protobuf/proto/errors.proto
  26. 6 0
      rust-lib/flowy-ws/src/protobuf/proto/msg.proto
  27. 63 137
      rust-lib/flowy-ws/src/ws.rs

+ 4 - 0
app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/errors.pbenum.dart

@@ -11,9 +11,13 @@ import 'package:protobuf/protobuf.dart' as $pb;
 
 class ErrorCode extends $pb.ProtobufEnum {
   static const ErrorCode InternalError = ErrorCode._(0, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'InternalError');
+  static const ErrorCode DuplicateSource = ErrorCode._(1, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'DuplicateSource');
+  static const ErrorCode UnsupportedMessage = ErrorCode._(2, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'UnsupportedMessage');
 
   static const $core.List<ErrorCode> values = <ErrorCode> [
     InternalError,
+    DuplicateSource,
+    UnsupportedMessage,
   ];
 
   static final $core.Map<$core.int, ErrorCode> _byValue = $pb.ProtobufEnum.initByValue(values);

+ 3 - 1
app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/errors.pbjson.dart

@@ -13,11 +13,13 @@ const ErrorCode$json = const {
   '1': 'ErrorCode',
   '2': const [
     const {'1': 'InternalError', '2': 0},
+    const {'1': 'DuplicateSource', '2': 1},
+    const {'1': 'UnsupportedMessage', '2': 2},
   ],
 };
 
 /// Descriptor for `ErrorCode`. Decode as a `google.protobuf.EnumDescriptorProto`.
-final $typed_data.Uint8List errorCodeDescriptor = $convert.base64Decode('CglFcnJvckNvZGUSEQoNSW50ZXJuYWxFcnJvchAA');
+final $typed_data.Uint8List errorCodeDescriptor = $convert.base64Decode('CglFcnJvckNvZGUSEQoNSW50ZXJuYWxFcnJvchAAEhMKD0R1cGxpY2F0ZVNvdXJjZRABEhYKElVuc3VwcG9ydGVkTWVzc2FnZRAC');
 @$core.Deprecated('Use wsErrorDescriptor instead')
 const WsError$json = const {
   '1': 'WsError',

+ 72 - 0
app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pb.dart

@@ -0,0 +1,72 @@
+///
+//  Generated code. Do not modify.
+//  source: msg.proto
+//
+// @dart = 2.12
+// ignore_for_file: annotate_overrides,camel_case_types,unnecessary_const,non_constant_identifier_names,library_prefixes,unused_import,unused_shown_name,return_of_invalid_type,unnecessary_this,prefer_final_fields
+
+import 'dart:core' as $core;
+
+import 'package:protobuf/protobuf.dart' as $pb;
+
+class WsMessage extends $pb.GeneratedMessage {
+  static final $pb.BuilderInfo _i = $pb.BuilderInfo(const $core.bool.fromEnvironment('protobuf.omit_message_names') ? '' : 'WsMessage', createEmptyInstance: create)
+    ..aOS(1, const $core.bool.fromEnvironment('protobuf.omit_field_names') ? '' : 'source')
+    ..a<$core.List<$core.int>>(2, const $core.bool.fromEnvironment('protobuf.omit_field_names') ? '' : 'data', $pb.PbFieldType.OY)
+    ..hasRequiredFields = false
+  ;
+
+  WsMessage._() : super();
+  factory WsMessage({
+    $core.String? source,
+    $core.List<$core.int>? data,
+  }) {
+    final _result = create();
+    if (source != null) {
+      _result.source = source;
+    }
+    if (data != null) {
+      _result.data = data;
+    }
+    return _result;
+  }
+  factory WsMessage.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r);
+  factory WsMessage.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r);
+  @$core.Deprecated(
+  'Using this can add significant overhead to your binary. '
+  'Use [GeneratedMessageGenericExtensions.deepCopy] instead. '
+  'Will be removed in next major version')
+  WsMessage clone() => WsMessage()..mergeFromMessage(this);
+  @$core.Deprecated(
+  'Using this can add significant overhead to your binary. '
+  'Use [GeneratedMessageGenericExtensions.rebuild] instead. '
+  'Will be removed in next major version')
+  WsMessage copyWith(void Function(WsMessage) updates) => super.copyWith((message) => updates(message as WsMessage)) as WsMessage; // ignore: deprecated_member_use
+  $pb.BuilderInfo get info_ => _i;
+  @$core.pragma('dart2js:noInline')
+  static WsMessage create() => WsMessage._();
+  WsMessage createEmptyInstance() => create();
+  static $pb.PbList<WsMessage> createRepeated() => $pb.PbList<WsMessage>();
+  @$core.pragma('dart2js:noInline')
+  static WsMessage getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor<WsMessage>(create);
+  static WsMessage? _defaultInstance;
+
+  @$pb.TagNumber(1)
+  $core.String get source => $_getSZ(0);
+  @$pb.TagNumber(1)
+  set source($core.String v) { $_setString(0, v); }
+  @$pb.TagNumber(1)
+  $core.bool hasSource() => $_has(0);
+  @$pb.TagNumber(1)
+  void clearSource() => clearField(1);
+
+  @$pb.TagNumber(2)
+  $core.List<$core.int> get data => $_getN(1);
+  @$pb.TagNumber(2)
+  set data($core.List<$core.int> v) { $_setBytes(1, v); }
+  @$pb.TagNumber(2)
+  $core.bool hasData() => $_has(1);
+  @$pb.TagNumber(2)
+  void clearData() => clearField(2);
+}
+

+ 7 - 0
app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbenum.dart

@@ -0,0 +1,7 @@
+///
+//  Generated code. Do not modify.
+//  source: msg.proto
+//
+// @dart = 2.12
+// ignore_for_file: annotate_overrides,camel_case_types,unnecessary_const,non_constant_identifier_names,library_prefixes,unused_import,unused_shown_name,return_of_invalid_type,unnecessary_this,prefer_final_fields
+

+ 21 - 0
app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbjson.dart

@@ -0,0 +1,21 @@
+///
+//  Generated code. Do not modify.
+//  source: msg.proto
+//
+// @dart = 2.12
+// ignore_for_file: annotate_overrides,camel_case_types,unnecessary_const,non_constant_identifier_names,library_prefixes,unused_import,unused_shown_name,return_of_invalid_type,unnecessary_this,prefer_final_fields,deprecated_member_use_from_same_package
+
+import 'dart:core' as $core;
+import 'dart:convert' as $convert;
+import 'dart:typed_data' as $typed_data;
+@$core.Deprecated('Use wsMessageDescriptor instead')
+const WsMessage$json = const {
+  '1': 'WsMessage',
+  '2': const [
+    const {'1': 'source', '3': 1, '4': 1, '5': 9, '10': 'source'},
+    const {'1': 'data', '3': 2, '4': 1, '5': 12, '10': 'data'},
+  ],
+};
+
+/// Descriptor for `WsMessage`. Decode as a `google.protobuf.DescriptorProto`.
+final $typed_data.Uint8List wsMessageDescriptor = $convert.base64Decode('CglXc01lc3NhZ2USFgoGc291cmNlGAEgASgJUgZzb3VyY2USEgoEZGF0YRgCIAEoDFIEZGF0YQ==');

+ 9 - 0
app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbserver.dart

@@ -0,0 +1,9 @@
+///
+//  Generated code. Do not modify.
+//  source: msg.proto
+//
+// @dart = 2.12
+// ignore_for_file: annotate_overrides,camel_case_types,unnecessary_const,non_constant_identifier_names,library_prefixes,unused_import,unused_shown_name,return_of_invalid_type,unnecessary_this,prefer_final_fields,deprecated_member_use_from_same_package
+
+export 'msg.pb.dart';
+

+ 1 - 0
app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/protobuf.dart

@@ -1,2 +1,3 @@
 // Auto-generated, do not edit 
 export './errors.pb.dart';
+export './msg.pb.dart';

+ 1 - 1
backend/src/middleware/auth_middleware.rs

@@ -56,7 +56,7 @@ where
     fn call(&self, req: ServiceRequest) -> Self::Future {
         let mut authenticate_pass: bool = false;
         for ignore_route in IGNORE_ROUTES.iter() {
-            log::info!("ignore: {}, path: {}", ignore_route, req.path());
+            // log::info!("ignore: {}, path: {}", ignore_route, req.path());
             if req.path().starts_with(ignore_route) {
                 authenticate_pass = true;
                 break;

+ 27 - 21
backend/tests/api/helper.rs

@@ -13,7 +13,7 @@ use sqlx::{Connection, Executor, PgConnection, PgPool};
 use uuid::Uuid;
 
 pub struct TestServer {
-    pub address: String,
+    pub host: String,
     pub port: u16,
     pub pg_pool: PgPool,
     pub user_token: Option<String>,
@@ -30,12 +30,12 @@ impl TestServer {
     }
 
     pub async fn sign_in(&self, params: SignInParams) -> Result<SignInResponse, UserError> {
-        let url = format!("{}/api/auth", self.address);
+        let url = format!("{}/api/auth", self.http_addr());
         user_sign_in_request(params, &url).await
     }
 
     pub async fn sign_out(&self) {
-        let url = format!("{}/api/auth", self.address);
+        let url = format!("{}/api/auth", self.http_addr());
         let _ = user_sign_out_request(self.user_token(), &url)
             .await
             .unwrap();
@@ -54,7 +54,7 @@ impl TestServer {
     }
 
     pub async fn get_user_profile(&self) -> UserProfile {
-        let url = format!("{}/api/user", self.address);
+        let url = format!("{}/api/user", self.http_addr());
         let user_profile = get_user_profile_request(self.user_token(), &url)
             .await
             .unwrap();
@@ -62,12 +62,12 @@ impl TestServer {
     }
 
     pub async fn update_user_profile(&self, params: UpdateUserParams) -> Result<(), UserError> {
-        let url = format!("{}/api/user", self.address);
+        let url = format!("{}/api/user", self.http_addr());
         update_user_profile_request(self.user_token(), params, &url).await
     }
 
     pub async fn create_workspace(&self, params: CreateWorkspaceParams) -> Workspace {
-        let url = format!("{}/api/workspace", self.address);
+        let url = format!("{}/api/workspace", self.http_addr());
         let workspace = create_workspace_request(self.user_token(), params, &url)
             .await
             .unwrap();
@@ -75,7 +75,7 @@ impl TestServer {
     }
 
     pub async fn read_workspaces(&self, params: QueryWorkspaceParams) -> RepeatedWorkspace {
-        let url = format!("{}/api/workspace", self.address);
+        let url = format!("{}/api/workspace", self.http_addr());
         let workspaces = read_workspaces_request(self.user_token(), params, &url)
             .await
             .unwrap();
@@ -83,21 +83,21 @@ impl TestServer {
     }
 
     pub async fn update_workspace(&self, params: UpdateWorkspaceParams) {
-        let url = format!("{}/api/workspace", self.address);
+        let url = format!("{}/api/workspace", self.http_addr());
         update_workspace_request(self.user_token(), params, &url)
             .await
             .unwrap();
     }
 
     pub async fn delete_workspace(&self, params: DeleteWorkspaceParams) {
-        let url = format!("{}/api/workspace", self.address);
+        let url = format!("{}/api/workspace", self.http_addr());
         delete_workspace_request(self.user_token(), params, &url)
             .await
             .unwrap();
     }
 
     pub async fn create_app(&self, params: CreateAppParams) -> App {
-        let url = format!("{}/api/app", self.address);
+        let url = format!("{}/api/app", self.http_addr());
         let app = create_app_request(self.user_token(), params, &url)
             .await
             .unwrap();
@@ -105,7 +105,7 @@ impl TestServer {
     }
 
     pub async fn read_app(&self, params: QueryAppParams) -> Option<App> {
-        let url = format!("{}/api/app", self.address);
+        let url = format!("{}/api/app", self.http_addr());
         let app = read_app_request(self.user_token(), params, &url)
             .await
             .unwrap();
@@ -113,21 +113,21 @@ impl TestServer {
     }
 
     pub async fn update_app(&self, params: UpdateAppParams) {
-        let url = format!("{}/api/app", self.address);
+        let url = format!("{}/api/app", self.http_addr());
         update_app_request(self.user_token(), params, &url)
             .await
             .unwrap();
     }
 
     pub async fn delete_app(&self, params: DeleteAppParams) {
-        let url = format!("{}/api/app", self.address);
+        let url = format!("{}/api/app", self.http_addr());
         delete_app_request(self.user_token(), params, &url)
             .await
             .unwrap();
     }
 
     pub async fn create_view(&self, params: CreateViewParams) -> View {
-        let url = format!("{}/api/view", self.address);
+        let url = format!("{}/api/view", self.http_addr());
         let view = create_view_request(self.user_token(), params, &url)
             .await
             .unwrap();
@@ -135,7 +135,7 @@ impl TestServer {
     }
 
     pub async fn read_view(&self, params: QueryViewParams) -> Option<View> {
-        let url = format!("{}/api/view", self.address);
+        let url = format!("{}/api/view", self.http_addr());
         let view = read_view_request(self.user_token(), params, &url)
             .await
             .unwrap();
@@ -143,21 +143,21 @@ impl TestServer {
     }
 
     pub async fn update_view(&self, params: UpdateViewParams) {
-        let url = format!("{}/api/view", self.address);
+        let url = format!("{}/api/view", self.http_addr());
         update_view_request(self.user_token(), params, &url)
             .await
             .unwrap();
     }
 
     pub async fn delete_view(&self, params: DeleteViewParams) {
-        let url = format!("{}/api/view", self.address);
+        let url = format!("{}/api/view", self.http_addr());
         delete_view_request(self.user_token(), params, &url)
             .await
             .unwrap();
     }
 
     pub async fn read_doc(&self, params: QueryDocParams) -> Option<Doc> {
-        let url = format!("{}/api/doc", self.address);
+        let url = format!("{}/api/doc", self.http_addr());
         let doc = read_doc_request(self.user_token(), params, &url)
             .await
             .unwrap();
@@ -175,13 +175,19 @@ impl TestServer {
     }
 
     pub(crate) async fn register(&self, params: SignUpParams) -> SignUpResponse {
-        let url = format!("{}/api/register", self.address);
+        let url = format!("{}/api/register", self.http_addr());
         let response = user_sign_up_request(params, &url).await.unwrap();
         response
     }
 
+    pub(crate) fn http_addr(&self) -> String { format!("http://{}", self.host) }
+
     pub(crate) fn ws_addr(&self) -> String {
-        format!("{}/ws/{}", self.address, self.user_token.as_ref().unwrap())
+        format!(
+            "ws://{}/ws/{}",
+            self.host,
+            self.user_token.as_ref().unwrap()
+        )
     }
 }
 pub async fn spawn_server() -> TestServer {
@@ -206,7 +212,7 @@ pub async fn spawn_server() -> TestServer {
     });
 
     TestServer {
-        address: format!("http://localhost:{}", application_port),
+        host: format!("localhost:{}", application_port),
         port: application_port,
         pg_pool: get_connection_pool(&configuration.database)
             .await

+ 1 - 1
backend/tests/api/ws.rs

@@ -6,5 +6,5 @@ async fn ws_connect() {
     let server = TestServer::new().await;
     let mut controller = WsController::new();
     let addr = server.ws_addr();
-    let _ = controller.connect(addr).await.unwrap();
+    let _ = controller.connect(addr).unwrap().await;
 }

+ 1 - 1
rust-lib/flowy-ast/src/ty_ext.rs

@@ -60,7 +60,7 @@ pub fn parse_ty<'a>(ctxt: &Ctxt, ty: &'a syn::Type) -> Option<TyInfo<'a>> {
                 "Vec" => generate_vec_ty_info(ctxt, seg, bracketed),
                 "Option" => generate_option_ty_info(ctxt, ty, seg, bracketed),
                 _ => {
-                    panic!("Unsupported ty")
+                    panic!("Unsupported ty {}", seg.ident.to_string())
                 },
             }
         } else {

+ 1 - 0
rust-lib/flowy-derive/src/derive_cache/derive_cache.rs

@@ -54,6 +54,7 @@ pub fn category_from_str(type_str: &str) -> TypeCategory {
         | "RepeatedView"
         | "WorkspaceError"
         | "WsError"
+        | "WsMessage"
         | "CreateDocParams"
         | "Doc"
         | "SaveDocParams"

+ 0 - 1
rust-lib/flowy-dispatch/tests/api/module.rs

@@ -1,4 +1,3 @@
-use crate::helper::*;
 use flowy_dispatch::prelude::*;
 use std::sync::Arc;
 

+ 11 - 11
rust-lib/flowy-net/src/config.rs

@@ -1,20 +1,20 @@
 use lazy_static::lazy_static;
 
-pub const HOST: &'static str = "http://localhost:8000";
-
+pub const HOST: &'static str = "localhost:8000";
+pub const SCHEMA: &'static str = "http://";
 pub const HEADER_TOKEN: &'static str = "token";
 
 lazy_static! {
-    pub static ref SIGN_UP_URL: String = format!("{}/api/register", HOST);
-    pub static ref SIGN_IN_URL: String = format!("{}/api/auth", HOST);
-    pub static ref SIGN_OUT_URL: String = format!("{}/api/auth", HOST);
-    pub static ref USER_PROFILE_URL: String = format!("{}/api/user", HOST);
+    pub static ref SIGN_UP_URL: String = format!("{}/{}/api/register", SCHEMA, HOST);
+    pub static ref SIGN_IN_URL: String = format!("{}/{}/api/auth", SCHEMA, HOST);
+    pub static ref SIGN_OUT_URL: String = format!("{}/{}/api/auth", SCHEMA, HOST);
+    pub static ref USER_PROFILE_URL: String = format!("{}/{}/api/user", SCHEMA, HOST);
 
     //
-    pub static ref WORKSPACE_URL: String = format!("{}/api/workspace", HOST);
-    pub static ref APP_URL: String = format!("{}/api/app", HOST);
-    pub static ref VIEW_URL: String = format!("{}/api/view", HOST);
-    pub static ref DOC_URL: String = format!("{}/api/doc", HOST);
+    pub static ref WORKSPACE_URL: String = format!("{}/{}/api/workspace", SCHEMA, HOST);
+    pub static ref APP_URL: String = format!("{}/{}/api/app", SCHEMA, HOST);
+    pub static ref VIEW_URL: String = format!("{}/{}/api/view", SCHEMA, HOST);
+    pub static ref DOC_URL: String = format!("{}/{}/api/doc", SCHEMA, HOST);
 
-    pub static ref WS_ADDR: String = format!("ws://localhost:8000/ws");
+    pub static ref WS_ADDR: String = format!("ws://{}/ws", HOST);
 }

+ 4 - 0
rust-lib/flowy-user/src/errors.rs

@@ -109,6 +109,10 @@ impl std::convert::From<::r2d2::Error> for UserError {
     fn from(error: r2d2::Error) -> Self { UserError::internal().context(error) }
 }
 
+impl std::convert::From<flowy_ws::errors::WsError> for UserError {
+    fn from(error: flowy_ws::errors::WsError) -> Self { UserError::internal().context(error) }
+}
+
 // use diesel::result::{Error, DatabaseErrorKind};
 // use flowy_sqlite::ErrorKind;
 impl std::convert::From<flowy_sqlite::Error> for UserError {

+ 21 - 19
rust-lib/flowy-user/src/services/user/user_session.rs

@@ -18,10 +18,10 @@ use flowy_database::{
 };
 use flowy_infra::kv::KV;
 use flowy_sqlite::ConnectionPool;
-use flowy_ws::WsController;
+use flowy_ws::{WsController, WsMessage, WsMessageHandler};
 use parking_lot::RwLock;
 use serde::{Deserialize, Serialize};
-use std::sync::Arc;
+use std::{sync::Arc, time::Duration};
 
 pub struct UserSessionConfig {
     root_dir: String,
@@ -47,7 +47,7 @@ pub struct UserSession {
     #[allow(dead_code)]
     server: Server,
     session: RwLock<Option<Session>>,
-    ws: RwLock<WsController>,
+    ws_controller: RwLock<WsController>,
     status_callback: SessionStatusCallback,
 }
 
@@ -55,13 +55,13 @@ impl UserSession {
     pub fn new(config: UserSessionConfig, status_callback: SessionStatusCallback) -> Self {
         let db = UserDB::new(&config.root_dir);
         let server = construct_user_server();
-        let ws = RwLock::new(WsController::new());
+        let ws_controller = RwLock::new(WsController::new());
         let user_session = Self {
             database: db,
             config,
             server,
             session: RwLock::new(None),
-            ws,
+            ws_controller,
             status_callback,
         };
         user_session
@@ -172,6 +172,21 @@ impl UserSession {
     pub fn user_id(&self) -> Result<String, UserError> { Ok(self.get_session()?.user_id) }
 
     pub fn token(&self) -> Result<String, UserError> { Ok(self.get_session()?.token) }
+
+    pub fn add_ws_msg_handler(&self, handler: Arc<dyn WsMessageHandler>) -> Result<(), UserError> {
+        let _ = self.ws_controller.write().add_handler(handler)?;
+        Ok(())
+    }
+
+    pub fn send_ws_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), UserError> {
+        match self.ws_controller.try_read_for(Duration::from_millis(300)) {
+            None => Err(UserError::internal().context("Send ws message timeout")),
+            Some(guard) => {
+                let _ = guard.send_msg(msg)?;
+                Ok(())
+            },
+        }
+    }
 }
 
 impl UserSession {
@@ -263,20 +278,7 @@ impl UserSession {
 
     fn start_ws_connection(&self, token: &str) -> Result<(), UserError> {
         let addr = format!("{}/{}", flowy_net::config::WS_ADDR.as_str(), token);
-        log::debug!("🐴 Try to connect: {}", &addr);
-        let (conn, handlers) = self.ws.write().make_connect(addr);
-        tokio::spawn(async {
-            match conn.await {
-                Ok(_) => {
-                    log::debug!("🐴 ws connect success");
-                    let _ = handlers.await;
-                },
-                Err(e) => {
-                    // TODO: retry?
-                    log::error!("ws connect failed: {}", e);
-                },
-            }
-        });
+        let _ = self.ws_controller.write().connect(addr);
         Ok(())
     }
 }

+ 1 - 1
rust-lib/flowy-ws/Flowy.toml

@@ -1,2 +1,2 @@
-proto_crates = ["src/errors.rs"]
+proto_crates = ["src/errors.rs", "src/msg.rs"]
 event_files = []

+ 207 - 0
rust-lib/flowy-ws/src/connect.rs

@@ -0,0 +1,207 @@
+use crate::{errors::WsError, MsgReceiver, MsgSender, WsMessage};
+use flowy_net::errors::ServerError;
+use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
+use futures_core::{future::BoxFuture, ready, Stream};
+use futures_util::{
+    future,
+    future::{Either, Select},
+    pin_mut,
+    FutureExt,
+    StreamExt,
+};
+use pin_project::pin_project;
+use std::{
+    collections::HashMap,
+    future::Future,
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll},
+};
+use tokio::{net::TcpStream, task::JoinHandle};
+use tokio_tungstenite::{
+    connect_async,
+    tungstenite::{handshake::client::Response, http::StatusCode, Error, Message},
+    MaybeTlsStream,
+    WebSocketStream,
+};
+
+#[pin_project]
+pub struct WsConnection {
+    msg_tx: Option<MsgSender>,
+    ws_rx: Option<MsgReceiver>,
+    #[pin]
+    fut: BoxFuture<'static, Result<(WebSocketStream<MaybeTlsStream<TcpStream>>, Response), Error>>,
+}
+
+impl WsConnection {
+    pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, addr: String) -> Self {
+        WsConnection {
+            msg_tx: Some(msg_tx),
+            ws_rx: Some(ws_rx),
+            fut: Box::pin(async move { connect_async(&addr).await }),
+        }
+    }
+}
+
+impl Future for WsConnection {
+    type Output = Result<WsStream, ServerError>;
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        // [[pin]]
+        // poll async function.  The following methods not work.
+        // 1.
+        // let f = connect_async("");
+        // pin_mut!(f);
+        // ready!(Pin::new(&mut a).poll(cx))
+        //
+        // 2.ready!(Pin::new(&mut Box::pin(connect_async(""))).poll(cx))
+        //
+        // An async method calls poll multiple times and might return to the executor. A
+        // single poll call can only return to the executor once and will get
+        // resumed through another poll invocation. the connect_async call multiple time
+        // from the beginning. So I use fut to hold the future and continue to
+        // poll it. (Fix me if i was wrong)
+        loop {
+            return match ready!(self.as_mut().project().fut.poll(cx)) {
+                Ok((stream, _)) => {
+                    log::debug!("🐴 ws connect success");
+                    let (msg_tx, ws_rx) = (self.msg_tx.take().unwrap(), self.ws_rx.take().unwrap());
+                    Poll::Ready(Ok(WsStream::new(msg_tx, ws_rx, stream)))
+                },
+                Err(error) => Poll::Ready(Err(error_to_flowy_response(error))),
+            };
+        }
+    }
+}
+
+#[pin_project]
+pub struct WsStream {
+    msg_tx: MsgSender,
+    #[pin]
+    fut: Option<(BoxFuture<'static, ()>, BoxFuture<'static, ()>)>,
+}
+
+impl WsStream {
+    pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, stream: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
+        let (ws_write, ws_read) = stream.split();
+        let to_ws = ws_rx.map(Ok).forward(ws_write);
+        let from_ws = ws_read.for_each(|message| async {
+            // handle_new_message(msg_tx.clone(), message)
+        });
+        // pin_mut!(to_ws, from_ws);
+        Self {
+            msg_tx,
+            fut: Some((
+                Box::pin(async move {
+                    let _ = from_ws.await;
+                }),
+                Box::pin(async move {
+                    let _ = to_ws.await;
+                }),
+            )),
+        }
+    }
+}
+
+impl Future for WsStream {
+    type Output = ();
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let (mut a, mut b) = self.fut.take().unwrap();
+        match a.poll_unpin(cx) {
+            Poll::Ready(x) => Poll::Ready(()),
+            Poll::Pending => match b.poll_unpin(cx) {
+                Poll::Ready(x) => Poll::Ready(()),
+                Poll::Pending => {
+                    // self.fut = Some((a, b));
+                    Poll::Pending
+                },
+            },
+        }
+    }
+}
+
+// pub struct WsStream {
+//     msg_tx: Option<MsgSender>,
+//     ws_rx: Option<MsgReceiver>,
+//     stream: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
+// }
+//
+// impl WsStream {
+//     pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, stream:
+// WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {         Self {
+//             msg_tx: Some(msg_tx),
+//             ws_rx: Some(ws_rx),
+//             stream: Some(stream),
+//         }
+//     }
+//
+//     pub fn start(mut self) -> JoinHandle<()> {
+//         let (msg_tx, ws_rx) = (self.msg_tx.take().unwrap(),
+// self.ws_rx.take().unwrap());         let (ws_write, ws_read) =
+// self.stream.take().unwrap().split();         tokio::spawn(async move {
+//             let to_ws = ws_rx.map(Ok).forward(ws_write);
+//             let from_ws = ws_read.for_each(|message| async {
+// handle_new_message(msg_tx.clone(), message) });             pin_mut!(to_ws,
+// from_ws);
+//
+//             match future::select(to_ws, from_ws).await {
+//                 Either::Left(_l) => {
+//                     log::info!("ws left");
+//                 },
+//                 Either::Right(_r) => {
+//                     log::info!("ws right");
+//                 },
+//             }
+//         })
+//     }
+// }
+//
+// impl Future for WsStream {
+//     type Output = ();
+//     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) ->
+// Poll<Self::Output> {         let (msg_tx, ws_rx) =
+// (self.msg_tx.take().unwrap(), self.ws_rx.take().unwrap());         let
+// (ws_write, ws_read) = self.stream.take().unwrap().split();         let to_ws
+// = ws_rx.map(Ok).forward(ws_write);         let from_ws =
+// ws_read.for_each(|message| async { handle_new_message(msg_tx.clone(),
+// message) });         pin_mut!(to_ws, from_ws);
+//
+//         loop {
+//             match ready!(Pin::new(&mut future::select(to_ws,
+// from_ws)).poll(cx)) {                 Either::Left(a) => {
+//                     //
+//                     return Poll::Ready(());
+//                 },
+//                 Either::Right(b) => {
+//                     //
+//                     return Poll::Ready(());
+//                 },
+//             }
+//         }
+//     }
+// }
+
+fn handle_new_message(tx: MsgSender, message: Result<Message, Error>) {
+    match message {
+        Ok(Message::Binary(bytes)) => match tx.unbounded_send(Message::Binary(bytes)) {
+            Ok(_) => {},
+            Err(e) => log::error!("tx send error: {:?}", e),
+        },
+        Ok(_) => {},
+        Err(e) => log::error!("ws read error: {:?}", e),
+    }
+}
+
+fn error_to_flowy_response(error: tokio_tungstenite::tungstenite::Error) -> ServerError {
+    let error = match error {
+        Error::Http(response) => {
+            if response.status() == StatusCode::UNAUTHORIZED {
+                ServerError::unauthorized()
+            } else {
+                ServerError::internal().context(response)
+            }
+        },
+        _ => ServerError::internal().context(error),
+    };
+
+    error
+}

+ 9 - 1
rust-lib/flowy-ws/src/errors.rs

@@ -36,11 +36,15 @@ impl WsError {
     }
 
     static_user_error!(internal, ErrorCode::InternalError);
+    static_user_error!(duplicate_source, ErrorCode::DuplicateSource);
+    static_user_error!(unsupported_message, ErrorCode::UnsupportedMessage);
 }
 
 #[derive(Debug, Clone, ProtoBuf_Enum, Display, PartialEq, Eq)]
 pub enum ErrorCode {
-    InternalError = 0,
+    InternalError      = 0,
+    DuplicateSource    = 1,
+    UnsupportedMessage = 2,
 }
 
 impl std::default::Default for ErrorCode {
@@ -51,6 +55,10 @@ impl std::convert::From<url::ParseError> for WsError {
     fn from(error: ParseError) -> Self { WsError::internal().context(error) }
 }
 
+impl std::convert::From<protobuf::ProtobufError> for WsError {
+    fn from(error: protobuf::ProtobufError) -> Self { WsError::internal().context(error) }
+}
+
 impl std::convert::From<futures_channel::mpsc::TrySendError<Message>> for WsError {
     fn from(error: TrySendError<Message>) -> Self { WsError::internal().context(error) }
 }

+ 3 - 0
rust-lib/flowy-ws/src/lib.rs

@@ -1,5 +1,8 @@
+mod connect;
 pub mod errors;
+mod msg;
 pub mod protobuf;
 mod ws;
 
+pub use msg::*;
 pub use ws::*;

+ 38 - 0
rust-lib/flowy-ws/src/msg.rs

@@ -0,0 +1,38 @@
+use bytes::Bytes;
+use flowy_derive::ProtoBuf;
+use std::convert::{TryFrom, TryInto};
+use tokio_tungstenite::tungstenite::Message;
+
+#[derive(ProtoBuf, Debug, Clone, Default)]
+pub struct WsMessage {
+    #[pb(index = 1)]
+    pub source: String,
+
+    #[pb(index = 2)]
+    pub data: Vec<u8>,
+}
+
+impl std::convert::Into<Message> for WsMessage {
+    fn into(self) -> Message {
+        let result: Result<Bytes, ::protobuf::ProtobufError> = self.try_into();
+        match result {
+            Ok(bytes) => Message::Binary(bytes.to_vec()),
+            Err(e) => {
+                log::error!("WsMessage serialize error: {:?}", e);
+                Message::Binary(vec![])
+            },
+        }
+    }
+}
+
+impl std::convert::From<Message> for WsMessage {
+    fn from(value: Message) -> Self {
+        match value {
+            Message::Binary(bytes) => WsMessage::try_from(Bytes::from(bytes)).unwrap(),
+            _ => {
+                log::error!("WsMessage deserialize failed. Unsupported message");
+                WsMessage::default()
+            },
+        }
+    }
+}

+ 24 - 13
rust-lib/flowy-ws/src/protobuf/model/errors.rs

@@ -216,6 +216,8 @@ impl ::protobuf::reflect::ProtobufValue for WsError {
 #[derive(Clone,PartialEq,Eq,Debug,Hash)]
 pub enum ErrorCode {
     InternalError = 0,
+    DuplicateSource = 1,
+    UnsupportedMessage = 2,
 }
 
 impl ::protobuf::ProtobufEnum for ErrorCode {
@@ -226,6 +228,8 @@ impl ::protobuf::ProtobufEnum for ErrorCode {
     fn from_i32(value: i32) -> ::std::option::Option<ErrorCode> {
         match value {
             0 => ::std::option::Option::Some(ErrorCode::InternalError),
+            1 => ::std::option::Option::Some(ErrorCode::DuplicateSource),
+            2 => ::std::option::Option::Some(ErrorCode::UnsupportedMessage),
             _ => ::std::option::Option::None
         }
     }
@@ -233,6 +237,8 @@ impl ::protobuf::ProtobufEnum for ErrorCode {
     fn values() -> &'static [Self] {
         static values: &'static [ErrorCode] = &[
             ErrorCode::InternalError,
+            ErrorCode::DuplicateSource,
+            ErrorCode::UnsupportedMessage,
         ];
         values
     }
@@ -262,19 +268,24 @@ impl ::protobuf::reflect::ProtobufValue for ErrorCode {
 
 static file_descriptor_proto_data: &'static [u8] = b"\
     \n\x0cerrors.proto\";\n\x07WsError\x12\x1e\n\x04code\x18\x01\x20\x01(\
-    \x0e2\n.ErrorCodeR\x04code\x12\x10\n\x03msg\x18\x02\x20\x01(\tR\x03msg*\
-    \x1e\n\tErrorCode\x12\x11\n\rInternalError\x10\0J\xd9\x01\n\x06\x12\x04\
-    \0\0\x08\x01\n\x08\n\x01\x0c\x12\x03\0\0\x12\n\n\n\x02\x04\0\x12\x04\x02\
-    \0\x05\x01\n\n\n\x03\x04\0\x01\x12\x03\x02\x08\x0f\n\x0b\n\x04\x04\0\x02\
-    \0\x12\x03\x03\x04\x17\n\x0c\n\x05\x04\0\x02\0\x06\x12\x03\x03\x04\r\n\
-    \x0c\n\x05\x04\0\x02\0\x01\x12\x03\x03\x0e\x12\n\x0c\n\x05\x04\0\x02\0\
-    \x03\x12\x03\x03\x15\x16\n\x0b\n\x04\x04\0\x02\x01\x12\x03\x04\x04\x13\n\
-    \x0c\n\x05\x04\0\x02\x01\x05\x12\x03\x04\x04\n\n\x0c\n\x05\x04\0\x02\x01\
-    \x01\x12\x03\x04\x0b\x0e\n\x0c\n\x05\x04\0\x02\x01\x03\x12\x03\x04\x11\
-    \x12\n\n\n\x02\x05\0\x12\x04\x06\0\x08\x01\n\n\n\x03\x05\0\x01\x12\x03\
-    \x06\x05\x0e\n\x0b\n\x04\x05\0\x02\0\x12\x03\x07\x04\x16\n\x0c\n\x05\x05\
-    \0\x02\0\x01\x12\x03\x07\x04\x11\n\x0c\n\x05\x05\0\x02\0\x02\x12\x03\x07\
-    \x14\x15b\x06proto3\
+    \x0e2\n.ErrorCodeR\x04code\x12\x10\n\x03msg\x18\x02\x20\x01(\tR\x03msg*K\
+    \n\tErrorCode\x12\x11\n\rInternalError\x10\0\x12\x13\n\x0fDuplicateSourc\
+    e\x10\x01\x12\x16\n\x12UnsupportedMessage\x10\x02J\xab\x02\n\x06\x12\x04\
+    \0\0\n\x01\n\x08\n\x01\x0c\x12\x03\0\0\x12\n\n\n\x02\x04\0\x12\x04\x02\0\
+    \x05\x01\n\n\n\x03\x04\0\x01\x12\x03\x02\x08\x0f\n\x0b\n\x04\x04\0\x02\0\
+    \x12\x03\x03\x04\x17\n\x0c\n\x05\x04\0\x02\0\x06\x12\x03\x03\x04\r\n\x0c\
+    \n\x05\x04\0\x02\0\x01\x12\x03\x03\x0e\x12\n\x0c\n\x05\x04\0\x02\0\x03\
+    \x12\x03\x03\x15\x16\n\x0b\n\x04\x04\0\x02\x01\x12\x03\x04\x04\x13\n\x0c\
+    \n\x05\x04\0\x02\x01\x05\x12\x03\x04\x04\n\n\x0c\n\x05\x04\0\x02\x01\x01\
+    \x12\x03\x04\x0b\x0e\n\x0c\n\x05\x04\0\x02\x01\x03\x12\x03\x04\x11\x12\n\
+    \n\n\x02\x05\0\x12\x04\x06\0\n\x01\n\n\n\x03\x05\0\x01\x12\x03\x06\x05\
+    \x0e\n\x0b\n\x04\x05\0\x02\0\x12\x03\x07\x04\x16\n\x0c\n\x05\x05\0\x02\0\
+    \x01\x12\x03\x07\x04\x11\n\x0c\n\x05\x05\0\x02\0\x02\x12\x03\x07\x14\x15\
+    \n\x0b\n\x04\x05\0\x02\x01\x12\x03\x08\x04\x18\n\x0c\n\x05\x05\0\x02\x01\
+    \x01\x12\x03\x08\x04\x13\n\x0c\n\x05\x05\0\x02\x01\x02\x12\x03\x08\x16\
+    \x17\n\x0b\n\x04\x05\0\x02\x02\x12\x03\t\x04\x1b\n\x0c\n\x05\x05\0\x02\
+    \x02\x01\x12\x03\t\x04\x16\n\x0c\n\x05\x05\0\x02\x02\x02\x12\x03\t\x19\
+    \x1ab\x06proto3\
 ";
 
 static file_descriptor_proto_lazy: ::protobuf::rt::LazyV2<::protobuf::descriptor::FileDescriptorProto> = ::protobuf::rt::LazyV2::INIT;

+ 3 - 0
rust-lib/flowy-ws/src/protobuf/model/mod.rs

@@ -2,3 +2,6 @@
 
 mod errors; 
 pub use errors::*; 
+
+mod msg; 
+pub use msg::*; 

+ 250 - 0
rust-lib/flowy-ws/src/protobuf/model/msg.rs

@@ -0,0 +1,250 @@
+// This file is generated by rust-protobuf 2.22.1. Do not edit
+// @generated
+
+// https://github.com/rust-lang/rust-clippy/issues/702
+#![allow(unknown_lints)]
+#![allow(clippy::all)]
+
+#![allow(unused_attributes)]
+#![cfg_attr(rustfmt, rustfmt::skip)]
+
+#![allow(box_pointers)]
+#![allow(dead_code)]
+#![allow(missing_docs)]
+#![allow(non_camel_case_types)]
+#![allow(non_snake_case)]
+#![allow(non_upper_case_globals)]
+#![allow(trivial_casts)]
+#![allow(unused_imports)]
+#![allow(unused_results)]
+//! Generated file from `msg.proto`
+
+/// Generated files are compatible only with the same version
+/// of protobuf runtime.
+// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_22_1;
+
+#[derive(PartialEq,Clone,Default)]
+pub struct WsMessage {
+    // message fields
+    pub source: ::std::string::String,
+    pub data: ::std::vec::Vec<u8>,
+    // special fields
+    pub unknown_fields: ::protobuf::UnknownFields,
+    pub cached_size: ::protobuf::CachedSize,
+}
+
+impl<'a> ::std::default::Default for &'a WsMessage {
+    fn default() -> &'a WsMessage {
+        <WsMessage as ::protobuf::Message>::default_instance()
+    }
+}
+
+impl WsMessage {
+    pub fn new() -> WsMessage {
+        ::std::default::Default::default()
+    }
+
+    // string source = 1;
+
+
+    pub fn get_source(&self) -> &str {
+        &self.source
+    }
+    pub fn clear_source(&mut self) {
+        self.source.clear();
+    }
+
+    // Param is passed by value, moved
+    pub fn set_source(&mut self, v: ::std::string::String) {
+        self.source = v;
+    }
+
+    // Mutable pointer to the field.
+    // If field is not initialized, it is initialized with default value first.
+    pub fn mut_source(&mut self) -> &mut ::std::string::String {
+        &mut self.source
+    }
+
+    // Take field
+    pub fn take_source(&mut self) -> ::std::string::String {
+        ::std::mem::replace(&mut self.source, ::std::string::String::new())
+    }
+
+    // bytes data = 2;
+
+
+    pub fn get_data(&self) -> &[u8] {
+        &self.data
+    }
+    pub fn clear_data(&mut self) {
+        self.data.clear();
+    }
+
+    // Param is passed by value, moved
+    pub fn set_data(&mut self, v: ::std::vec::Vec<u8>) {
+        self.data = v;
+    }
+
+    // Mutable pointer to the field.
+    // If field is not initialized, it is initialized with default value first.
+    pub fn mut_data(&mut self) -> &mut ::std::vec::Vec<u8> {
+        &mut self.data
+    }
+
+    // Take field
+    pub fn take_data(&mut self) -> ::std::vec::Vec<u8> {
+        ::std::mem::replace(&mut self.data, ::std::vec::Vec::new())
+    }
+}
+
+impl ::protobuf::Message for WsMessage {
+    fn is_initialized(&self) -> bool {
+        true
+    }
+
+    fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::ProtobufResult<()> {
+        while !is.eof()? {
+            let (field_number, wire_type) = is.read_tag_unpack()?;
+            match field_number {
+                1 => {
+                    ::protobuf::rt::read_singular_proto3_string_into(wire_type, is, &mut self.source)?;
+                },
+                2 => {
+                    ::protobuf::rt::read_singular_proto3_bytes_into(wire_type, is, &mut self.data)?;
+                },
+                _ => {
+                    ::protobuf::rt::read_unknown_or_skip_group(field_number, wire_type, is, self.mut_unknown_fields())?;
+                },
+            };
+        }
+        ::std::result::Result::Ok(())
+    }
+
+    // Compute sizes of nested messages
+    #[allow(unused_variables)]
+    fn compute_size(&self) -> u32 {
+        let mut my_size = 0;
+        if !self.source.is_empty() {
+            my_size += ::protobuf::rt::string_size(1, &self.source);
+        }
+        if !self.data.is_empty() {
+            my_size += ::protobuf::rt::bytes_size(2, &self.data);
+        }
+        my_size += ::protobuf::rt::unknown_fields_size(self.get_unknown_fields());
+        self.cached_size.set(my_size);
+        my_size
+    }
+
+    fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::ProtobufResult<()> {
+        if !self.source.is_empty() {
+            os.write_string(1, &self.source)?;
+        }
+        if !self.data.is_empty() {
+            os.write_bytes(2, &self.data)?;
+        }
+        os.write_unknown_fields(self.get_unknown_fields())?;
+        ::std::result::Result::Ok(())
+    }
+
+    fn get_cached_size(&self) -> u32 {
+        self.cached_size.get()
+    }
+
+    fn get_unknown_fields(&self) -> &::protobuf::UnknownFields {
+        &self.unknown_fields
+    }
+
+    fn mut_unknown_fields(&mut self) -> &mut ::protobuf::UnknownFields {
+        &mut self.unknown_fields
+    }
+
+    fn as_any(&self) -> &dyn (::std::any::Any) {
+        self as &dyn (::std::any::Any)
+    }
+    fn as_any_mut(&mut self) -> &mut dyn (::std::any::Any) {
+        self as &mut dyn (::std::any::Any)
+    }
+    fn into_any(self: ::std::boxed::Box<Self>) -> ::std::boxed::Box<dyn (::std::any::Any)> {
+        self
+    }
+
+    fn descriptor(&self) -> &'static ::protobuf::reflect::MessageDescriptor {
+        Self::descriptor_static()
+    }
+
+    fn new() -> WsMessage {
+        WsMessage::new()
+    }
+
+    fn descriptor_static() -> &'static ::protobuf::reflect::MessageDescriptor {
+        static descriptor: ::protobuf::rt::LazyV2<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::LazyV2::INIT;
+        descriptor.get(|| {
+            let mut fields = ::std::vec::Vec::new();
+            fields.push(::protobuf::reflect::accessor::make_simple_field_accessor::<_, ::protobuf::types::ProtobufTypeString>(
+                "source",
+                |m: &WsMessage| { &m.source },
+                |m: &mut WsMessage| { &mut m.source },
+            ));
+            fields.push(::protobuf::reflect::accessor::make_simple_field_accessor::<_, ::protobuf::types::ProtobufTypeBytes>(
+                "data",
+                |m: &WsMessage| { &m.data },
+                |m: &mut WsMessage| { &mut m.data },
+            ));
+            ::protobuf::reflect::MessageDescriptor::new_pb_name::<WsMessage>(
+                "WsMessage",
+                fields,
+                file_descriptor_proto()
+            )
+        })
+    }
+
+    fn default_instance() -> &'static WsMessage {
+        static instance: ::protobuf::rt::LazyV2<WsMessage> = ::protobuf::rt::LazyV2::INIT;
+        instance.get(WsMessage::new)
+    }
+}
+
+impl ::protobuf::Clear for WsMessage {
+    fn clear(&mut self) {
+        self.source.clear();
+        self.data.clear();
+        self.unknown_fields.clear();
+    }
+}
+
+impl ::std::fmt::Debug for WsMessage {
+    fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
+        ::protobuf::text_format::fmt(self, f)
+    }
+}
+
+impl ::protobuf::reflect::ProtobufValue for WsMessage {
+    fn as_ref(&self) -> ::protobuf::reflect::ReflectValueRef {
+        ::protobuf::reflect::ReflectValueRef::Message(self)
+    }
+}
+
+static file_descriptor_proto_data: &'static [u8] = b"\
+    \n\tmsg.proto\"7\n\tWsMessage\x12\x16\n\x06source\x18\x01\x20\x01(\tR\
+    \x06source\x12\x12\n\x04data\x18\x02\x20\x01(\x0cR\x04dataJ\x98\x01\n\
+    \x06\x12\x04\0\0\x05\x01\n\x08\n\x01\x0c\x12\x03\0\0\x12\n\n\n\x02\x04\0\
+    \x12\x04\x02\0\x05\x01\n\n\n\x03\x04\0\x01\x12\x03\x02\x08\x11\n\x0b\n\
+    \x04\x04\0\x02\0\x12\x03\x03\x04\x16\n\x0c\n\x05\x04\0\x02\0\x05\x12\x03\
+    \x03\x04\n\n\x0c\n\x05\x04\0\x02\0\x01\x12\x03\x03\x0b\x11\n\x0c\n\x05\
+    \x04\0\x02\0\x03\x12\x03\x03\x14\x15\n\x0b\n\x04\x04\0\x02\x01\x12\x03\
+    \x04\x04\x13\n\x0c\n\x05\x04\0\x02\x01\x05\x12\x03\x04\x04\t\n\x0c\n\x05\
+    \x04\0\x02\x01\x01\x12\x03\x04\n\x0e\n\x0c\n\x05\x04\0\x02\x01\x03\x12\
+    \x03\x04\x11\x12b\x06proto3\
+";
+
+static file_descriptor_proto_lazy: ::protobuf::rt::LazyV2<::protobuf::descriptor::FileDescriptorProto> = ::protobuf::rt::LazyV2::INIT;
+
+fn parse_descriptor_proto() -> ::protobuf::descriptor::FileDescriptorProto {
+    ::protobuf::Message::parse_from_bytes(file_descriptor_proto_data).unwrap()
+}
+
+pub fn file_descriptor_proto() -> &'static ::protobuf::descriptor::FileDescriptorProto {
+    file_descriptor_proto_lazy.get(|| {
+        parse_descriptor_proto()
+    })
+}

+ 2 - 0
rust-lib/flowy-ws/src/protobuf/proto/errors.proto

@@ -6,4 +6,6 @@ message WsError {
 }
 enum ErrorCode {
     InternalError = 0;
+    DuplicateSource = 1;
+    UnsupportedMessage = 2;
 }

+ 6 - 0
rust-lib/flowy-ws/src/protobuf/proto/msg.proto

@@ -0,0 +1,6 @@
+syntax = "proto3";
+
+message WsMessage {
+    string source = 1;
+    bytes data = 2;
+}

+ 63 - 137
rust-lib/flowy-ws/src/ws.rs

@@ -1,16 +1,23 @@
-use crate::errors::WsError;
-use flowy_net::{errors::ServerError, response::FlowyResponse};
+use crate::{connect::WsConnection, errors::WsError, WsMessage};
+use flowy_net::errors::ServerError;
 use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
 use futures_core::{future::BoxFuture, ready, Stream};
-use futures_util::{pin_mut, FutureExt, StreamExt};
+use futures_util::{
+    future,
+    future::{Either, Select},
+    pin_mut,
+    FutureExt,
+    StreamExt,
+};
 use pin_project::pin_project;
 use std::{
+    collections::HashMap,
     future::Future,
     pin::Pin,
     sync::Arc,
     task::{Context, Poll},
 };
-use tokio::net::TcpStream;
+use tokio::{net::TcpStream, task::JoinHandle};
 use tokio_tungstenite::{
     connect_async,
     tungstenite::{handshake::client::Response, http::StatusCode, Error, Message},
@@ -21,37 +28,56 @@ use tokio_tungstenite::{
 pub type MsgReceiver = UnboundedReceiver<Message>;
 pub type MsgSender = UnboundedSender<Message>;
 pub trait WsMessageHandler: Sync + Send + 'static {
-    fn can_handle(&self) -> bool;
-    fn receive_message(&self, msg: &Message);
-    fn send_message(&self, sender: Arc<WsSender>);
+    fn source(&self) -> String;
+    fn receive_message(&self, msg: WsMessage);
 }
 
 pub struct WsController {
     sender: Option<Arc<WsSender>>,
-    handlers: Vec<Arc<dyn WsMessageHandler>>,
+    handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
 }
 
 impl WsController {
     pub fn new() -> Self {
         let controller = Self {
             sender: None,
-            handlers: vec![],
+            handlers: HashMap::new(),
         };
-
         controller
     }
 
-    pub fn add_handlers(&mut self, handler: Arc<dyn WsMessageHandler>) { self.handlers.push(handler); }
-
-    #[allow(dead_code)]
-    pub async fn connect(&mut self, addr: String) -> Result<(), ServerError> {
-        let (conn, handlers) = self.make_connect(addr);
-        let _ = conn.await?;
-        let _ = tokio::spawn(handlers);
+    pub fn add_handler(&mut self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
+        let source = handler.source();
+        if self.handlers.contains_key(&source) {
+            return Err(WsError::duplicate_source());
+        }
+        self.handlers.insert(source, handler);
         Ok(())
     }
 
-    pub fn make_connect(&mut self, addr: String) -> (WsConnection, WsHandlers) {
+    pub fn connect(&mut self, addr: String) -> Result<JoinHandle<()>, ServerError> {
+        log::debug!("🐴 Try to connect: {}", &addr);
+        let (connection, handlers) = self.make_connect(addr);
+        Ok(tokio::spawn(async {
+            tokio::select! {
+                result = connection => {
+                    match result {
+                        Ok(stream) => {
+                            tokio::spawn(stream).await;
+                            // stream.start().await;
+                        },
+                        Err(e) => {
+                            // TODO: retry?
+                            log::error!("ws connect failed {:?}", e);
+                        }
+                    }
+                },
+                result = handlers => log::debug!("handlers completed {:?}", result),
+            };
+        }))
+    }
+
+    fn make_connect(&mut self, addr: String) -> (WsConnection, WsHandlers) {
         //                Stream                             User
         //               ┌───────────────┐                 ┌──────────────┐
         // ┌──────┐      │  ┌─────────┐  │    ┌────────┐   │  ┌────────┐  │
@@ -64,16 +90,15 @@ impl WsController {
         //               └───────────────┘                 └──────────────┘
         let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
         let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
-        let sender = Arc::new(WsSender::new(ws_tx));
         let handlers = self.handlers.clone();
-        self.sender = Some(sender.clone());
+        self.sender = Some(Arc::new(WsSender::new(ws_tx)));
         (WsConnection::new(msg_tx, ws_rx, addr), WsHandlers::new(handlers, msg_rx))
     }
 
-    pub fn send_message(&self, msg: Message) -> Result<(), WsError> {
-        match &self.sender {
-            None => panic!(),
-            Some(conn) => conn.send(msg),
+    pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
+        match self.sender.as_ref() {
+            None => Err(WsError::internal().context("Should call make_connect first")),
+            Some(sender) => sender.send(msg.into()),
         }
     }
 }
@@ -82,11 +107,11 @@ impl WsController {
 pub struct WsHandlers {
     #[pin]
     msg_rx: MsgReceiver,
-    handlers: Vec<Arc<dyn WsMessageHandler>>,
+    handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
 }
 
 impl WsHandlers {
-    fn new(handlers: Vec<Arc<dyn WsMessageHandler>>, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } }
+    fn new(handlers: HashMap<String, Arc<dyn WsMessageHandler>>, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } }
 }
 
 impl Future for WsHandlers {
@@ -94,130 +119,31 @@ impl Future for WsHandlers {
     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
         loop {
             match ready!(self.as_mut().project().msg_rx.poll_next(cx)) {
-                None => return Poll::Ready(()),
-                Some(message) => self.handlers.iter().for_each(|handler| {
-                    handler.receive_message(&message);
-                }),
-            }
-        }
-    }
-}
-
-#[pin_project]
-pub struct WsConnection {
-    msg_tx: Option<MsgSender>,
-    ws_rx: Option<MsgReceiver>,
-    #[pin]
-    fut: BoxFuture<'static, Result<(WebSocketStream<MaybeTlsStream<TcpStream>>, Response), Error>>,
-}
-
-impl WsConnection {
-    pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, addr: String) -> Self {
-        WsConnection {
-            msg_tx: Some(msg_tx),
-            ws_rx: Some(ws_rx),
-            fut: Box::pin(async move { connect_async(&addr).await }),
-        }
-    }
-}
-
-impl Future for WsConnection {
-    type Output = Result<(), ServerError>;
-    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        // [[pin]]
-        // poll async function.  The following methods not work.
-        // 1.
-        // let f = connect_async("");
-        // pin_mut!(f);
-        // ready!(Pin::new(&mut a).poll(cx))
-        //
-        // 2.ready!(Pin::new(&mut Box::pin(connect_async(""))).poll(cx))
-        //
-        // An async method calls poll multiple times and might return to the executor. A
-        // single poll call can only return to the executor once and will get
-        // resumed through another poll invocation. the connect_async call multiple time
-        // from the beginning. So I use fut to hold the future and continue to
-        // poll it. (Fix me if i was wrong)
-
-        loop {
-            return match ready!(self.as_mut().project().fut.poll(cx)) {
-                Ok((stream, _)) => {
-                    let mut ws_stream = WsStream {
-                        msg_tx: self.msg_tx.take(),
-                        ws_rx: self.ws_rx.take(),
-                        stream: Some(stream),
-                    };
-                    match Pin::new(&mut ws_stream).poll(cx) {
-                        Poll::Ready(_) => Poll::Ready(Ok(())),
-                        Poll::Pending => Poll::Pending,
-                    }
+                None => {
+                    // log::debug!("🐴 ws handler done");
+                    return Poll::Pending;
                 },
-                Err(error) => Poll::Ready(Err(error_to_flowy_response(error))),
-            };
-        }
-    }
-}
-
-fn error_to_flowy_response(error: tokio_tungstenite::tungstenite::Error) -> ServerError {
-    let error = match error {
-        Error::Http(response) => {
-            if response.status() == StatusCode::UNAUTHORIZED {
-                ServerError::unauthorized()
-            } else {
-                ServerError::internal().context(response)
-            }
-        },
-        _ => ServerError::internal().context(error),
-    };
-
-    error
-}
-
-struct WsStream {
-    msg_tx: Option<MsgSender>,
-    ws_rx: Option<MsgReceiver>,
-    stream: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
-}
-
-impl Future for WsStream {
-    type Output = ();
-    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        let (tx, rx) = (self.msg_tx.take().unwrap(), self.ws_rx.take().unwrap());
-        let (ws_write, ws_read) = self.stream.take().unwrap().split();
-        let to_ws = rx.map(Ok).forward(ws_write);
-        let from_ws = ws_read.for_each(|message| async {
-            match message {
-                Ok(message) => {
-                    match tx.unbounded_send(message) {
-                        Ok(_) => {},
-                        Err(e) => log::error!("tx send error: {:?}", e),
-                    };
+                Some(message) => {
+                    let message = WsMessage::from(message);
+                    match self.handlers.get(&message.source) {
+                        None => log::error!("Can't find any handler for message: {:?}", message),
+                        Some(handler) => handler.receive_message(message.clone()),
+                    }
                 },
-                Err(e) => log::error!("ws read error: {:?}", e),
             }
-        });
-
-        pin_mut!(to_ws, from_ws);
-        log::trace!("🐴 ws start poll stream");
-        match to_ws.poll_unpin(cx) {
-            Poll::Ready(_) => Poll::Ready(()),
-            Poll::Pending => match from_ws.poll_unpin(cx) {
-                Poll::Ready(_) => Poll::Ready(()),
-                Poll::Pending => Poll::Pending,
-            },
         }
     }
 }
 
-pub struct WsSender {
+struct WsSender {
     ws_tx: MsgSender,
 }
 
 impl WsSender {
     pub fn new(ws_tx: MsgSender) -> Self { Self { ws_tx } }
 
-    pub fn send(&self, msg: Message) -> Result<(), WsError> {
-        let _ = self.ws_tx.unbounded_send(msg)?;
+    pub fn send(&self, msg: WsMessage) -> Result<(), WsError> {
+        let _ = self.ws_tx.unbounded_send(msg.into()).map_err(|e| WsError::internal().context(e))?;
         Ok(())
     }
 }