Browse Source

stop local websocket before start connecting

appflowy 3 years ago
parent
commit
3c819ead49

+ 1 - 1
frontend/rust-lib/flowy-document/src/core/edit/queue.rs

@@ -15,7 +15,7 @@ use lib_ot::{
     core::{Interval, OperationTransformable},
     rich_text::{RichTextAttribute, RichTextDelta},
 };
-use std::{cell::Cell, sync::Arc};
+use std::sync::Arc;
 use tokio::sync::{oneshot, RwLock};
 
 // The EditorCommandQueue executes each command that will alter the document in

+ 63 - 35
frontend/rust-lib/flowy-net/src/ws/local/local_ws.rs

@@ -15,9 +15,13 @@ use tokio::sync::{broadcast, broadcast::Receiver, mpsc, mpsc::UnboundedReceiver}
 pub struct LocalWebSocket {
     receivers: Arc<DashMap<WSModule, Arc<dyn WSMessageReceiver>>>,
     state_sender: broadcast::Sender<WSConnectState>,
+    // LocalWSSender uses the mpsc::channel sender to simulate the web socket. It spawns a receiver that uses the
+    // LocalDocumentServer  to handle the message. The server will send the WebSocketRawMessage messages that will
+    // be handled by the WebSocketRawMessage receivers.
     ws_sender: LocalWSSender,
-    server: Arc<LocalDocumentServer>,
-    server_rx: RwLock<Option<UnboundedReceiver<WebSocketRawMessage>>>,
+    local_server: Arc<LocalDocumentServer>,
+    local_server_rx: RwLock<Option<UnboundedReceiver<WebSocketRawMessage>>>,
+    local_server_stop_tx: RwLock<Option<mpsc::Sender<()>>>,
     user_id: Arc<RwLock<Option<String>>>,
 }
 
@@ -28,67 +32,91 @@ impl std::default::Default for LocalWebSocket {
         let receivers = Arc::new(DashMap::new());
 
         let (server_tx, server_rx) = mpsc::unbounded_channel();
-        let server = Arc::new(LocalDocumentServer::new(server_tx));
-        let server_rx = RwLock::new(Some(server_rx));
-        let user_token = Arc::new(RwLock::new(None));
+        let local_server = Arc::new(LocalDocumentServer::new(server_tx));
+        let local_server_rx = RwLock::new(Some(server_rx));
+        let local_server_stop_tx = RwLock::new(None);
+        let user_id = Arc::new(RwLock::new(None));
 
         LocalWebSocket {
             receivers,
             state_sender,
             ws_sender,
-            server,
-            server_rx,
-            user_id: user_token,
+            local_server,
+            local_server_rx,
+            local_server_stop_tx,
+            user_id,
         }
     }
 }
 
 impl LocalWebSocket {
-    fn spawn_client(&self, _addr: String) {
+    fn restart_ws_receiver(&self) -> mpsc::Receiver<()> {
+        if let Some(stop_tx) = self.local_server_stop_tx.read().clone() {
+            tokio::spawn(async move {
+                let _ = stop_tx.send(()).await;
+            });
+        }
+        let (stop_tx, stop_rx) = mpsc::channel::<()>(1);
+        *self.local_server_stop_tx.write() = Some(stop_tx);
+        stop_rx
+    }
+
+    fn spawn_client_ws_receiver(&self, _addr: String) {
         let mut ws_receiver = self.ws_sender.subscribe();
-        let local_server = self.server.clone();
+        let local_server = self.local_server.clone();
         let user_id = self.user_id.clone();
+        let mut stop_rx = self.restart_ws_receiver();
         tokio::spawn(async move {
             loop {
-                // Polling the web socket message sent by user
-                match ws_receiver.recv().await {
-                    Ok(message) => {
-                        let user_id = user_id.read().clone();
-                        if user_id.is_none() {
-                            continue;
-                        }
-                        let user_id = user_id.unwrap();
-                        let server = local_server.clone();
-                        let fut = || async move {
-                            let bytes = Bytes::from(message.data);
-                            let client_data = DocumentClientWSData::try_from(bytes).map_err(internal_error)?;
-                            let _ = server
-                                .handle_client_data(client_data, user_id)
-                                .await
-                                .map_err(internal_error)?;
-                            Ok::<(), FlowyError>(())
-                        };
-                        match fut().await {
-                            Ok(_) => {},
-                            Err(e) => tracing::error!("[LocalWebSocket] error: {:?}", e),
+                tokio::select! {
+                    result = ws_receiver.recv() => {
+                        match result {
+                            Ok(message) => {
+                                let user_id = user_id.read().clone();
+                                handle_ws_raw_message(user_id, &local_server, message).await;
+                            },
+                            Err(e) => tracing::error!("[LocalWebSocket] error: {}", e),
                         }
+                    }
+                    _ = stop_rx.recv() => {
+                        break
                     },
-                    Err(e) => tracing::error!("[LocalWebSocket] error: {}", e),
                 }
             }
         });
     }
 }
 
+async fn handle_ws_raw_message(
+    user_id: Option<String>,
+    local_server: &Arc<LocalDocumentServer>,
+    message: WebSocketRawMessage,
+) {
+    let f = || async {
+        match user_id {
+            None => Ok(()),
+            Some(user_id) => {
+                let bytes = Bytes::from(message.data);
+                let client_data = DocumentClientWSData::try_from(bytes).map_err(internal_error)?;
+                let _ = local_server.handle_client_data(client_data, user_id).await?;
+                Ok::<(), FlowyError>(())
+            },
+        }
+    };
+    if let Err(e) = f().await {
+        tracing::error!("[LocalWebSocket] error: {:?}", e);
+    }
+}
+
 impl FlowyRawWebSocket for LocalWebSocket {
     fn initialize(&self) -> FutureResult<(), FlowyError> {
-        let mut server_rx = self.server_rx.write().take().expect("Only take once");
+        let mut server_rx = self.local_server_rx.write().take().expect("Only take once");
         let receivers = self.receivers.clone();
         tokio::spawn(async move {
             while let Some(message) = server_rx.recv().await {
                 match receivers.get(&message.module) {
                     None => tracing::error!("Can't find any handler for message: {:?}", message),
-                    Some(handler) => handler.receive_message(message.clone()),
+                    Some(receiver) => receiver.receive_message(message.clone()),
                 }
             }
         });
@@ -97,7 +125,7 @@ impl FlowyRawWebSocket for LocalWebSocket {
 
     fn start_connect(&self, addr: String, user_id: String) -> FutureResult<(), FlowyError> {
         *self.user_id.write() = Some(user_id);
-        self.spawn_client(addr);
+        self.spawn_client_ws_receiver(addr);
         FutureResult::new(async { Ok(()) })
     }