|
@@ -15,9 +15,13 @@ use tokio::sync::{broadcast, broadcast::Receiver, mpsc, mpsc::UnboundedReceiver}
|
|
|
pub struct LocalWebSocket {
|
|
|
receivers: Arc<DashMap<WSModule, Arc<dyn WSMessageReceiver>>>,
|
|
|
state_sender: broadcast::Sender<WSConnectState>,
|
|
|
+ // LocalWSSender uses the mpsc::channel sender to simulate the web socket. It spawns a receiver that uses the
|
|
|
+ // LocalDocumentServer to handle the message. The server will send the WebSocketRawMessage messages that will
|
|
|
+ // be handled by the WebSocketRawMessage receivers.
|
|
|
ws_sender: LocalWSSender,
|
|
|
- server: Arc<LocalDocumentServer>,
|
|
|
- server_rx: RwLock<Option<UnboundedReceiver<WebSocketRawMessage>>>,
|
|
|
+ local_server: Arc<LocalDocumentServer>,
|
|
|
+ local_server_rx: RwLock<Option<UnboundedReceiver<WebSocketRawMessage>>>,
|
|
|
+ local_server_stop_tx: RwLock<Option<mpsc::Sender<()>>>,
|
|
|
user_id: Arc<RwLock<Option<String>>>,
|
|
|
}
|
|
|
|
|
@@ -28,67 +32,91 @@ impl std::default::Default for LocalWebSocket {
|
|
|
let receivers = Arc::new(DashMap::new());
|
|
|
|
|
|
let (server_tx, server_rx) = mpsc::unbounded_channel();
|
|
|
- let server = Arc::new(LocalDocumentServer::new(server_tx));
|
|
|
- let server_rx = RwLock::new(Some(server_rx));
|
|
|
- let user_token = Arc::new(RwLock::new(None));
|
|
|
+ let local_server = Arc::new(LocalDocumentServer::new(server_tx));
|
|
|
+ let local_server_rx = RwLock::new(Some(server_rx));
|
|
|
+ let local_server_stop_tx = RwLock::new(None);
|
|
|
+ let user_id = Arc::new(RwLock::new(None));
|
|
|
|
|
|
LocalWebSocket {
|
|
|
receivers,
|
|
|
state_sender,
|
|
|
ws_sender,
|
|
|
- server,
|
|
|
- server_rx,
|
|
|
- user_id: user_token,
|
|
|
+ local_server,
|
|
|
+ local_server_rx,
|
|
|
+ local_server_stop_tx,
|
|
|
+ user_id,
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
impl LocalWebSocket {
|
|
|
- fn spawn_client(&self, _addr: String) {
|
|
|
+ fn restart_ws_receiver(&self) -> mpsc::Receiver<()> {
|
|
|
+ if let Some(stop_tx) = self.local_server_stop_tx.read().clone() {
|
|
|
+ tokio::spawn(async move {
|
|
|
+ let _ = stop_tx.send(()).await;
|
|
|
+ });
|
|
|
+ }
|
|
|
+ let (stop_tx, stop_rx) = mpsc::channel::<()>(1);
|
|
|
+ *self.local_server_stop_tx.write() = Some(stop_tx);
|
|
|
+ stop_rx
|
|
|
+ }
|
|
|
+
|
|
|
+ fn spawn_client_ws_receiver(&self, _addr: String) {
|
|
|
let mut ws_receiver = self.ws_sender.subscribe();
|
|
|
- let local_server = self.server.clone();
|
|
|
+ let local_server = self.local_server.clone();
|
|
|
let user_id = self.user_id.clone();
|
|
|
+ let mut stop_rx = self.restart_ws_receiver();
|
|
|
tokio::spawn(async move {
|
|
|
loop {
|
|
|
- // Polling the web socket message sent by user
|
|
|
- match ws_receiver.recv().await {
|
|
|
- Ok(message) => {
|
|
|
- let user_id = user_id.read().clone();
|
|
|
- if user_id.is_none() {
|
|
|
- continue;
|
|
|
- }
|
|
|
- let user_id = user_id.unwrap();
|
|
|
- let server = local_server.clone();
|
|
|
- let fut = || async move {
|
|
|
- let bytes = Bytes::from(message.data);
|
|
|
- let client_data = DocumentClientWSData::try_from(bytes).map_err(internal_error)?;
|
|
|
- let _ = server
|
|
|
- .handle_client_data(client_data, user_id)
|
|
|
- .await
|
|
|
- .map_err(internal_error)?;
|
|
|
- Ok::<(), FlowyError>(())
|
|
|
- };
|
|
|
- match fut().await {
|
|
|
- Ok(_) => {},
|
|
|
- Err(e) => tracing::error!("[LocalWebSocket] error: {:?}", e),
|
|
|
+ tokio::select! {
|
|
|
+ result = ws_receiver.recv() => {
|
|
|
+ match result {
|
|
|
+ Ok(message) => {
|
|
|
+ let user_id = user_id.read().clone();
|
|
|
+ handle_ws_raw_message(user_id, &local_server, message).await;
|
|
|
+ },
|
|
|
+ Err(e) => tracing::error!("[LocalWebSocket] error: {}", e),
|
|
|
}
|
|
|
+ }
|
|
|
+ _ = stop_rx.recv() => {
|
|
|
+ break
|
|
|
},
|
|
|
- Err(e) => tracing::error!("[LocalWebSocket] error: {}", e),
|
|
|
}
|
|
|
}
|
|
|
});
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+async fn handle_ws_raw_message(
|
|
|
+ user_id: Option<String>,
|
|
|
+ local_server: &Arc<LocalDocumentServer>,
|
|
|
+ message: WebSocketRawMessage,
|
|
|
+) {
|
|
|
+ let f = || async {
|
|
|
+ match user_id {
|
|
|
+ None => Ok(()),
|
|
|
+ Some(user_id) => {
|
|
|
+ let bytes = Bytes::from(message.data);
|
|
|
+ let client_data = DocumentClientWSData::try_from(bytes).map_err(internal_error)?;
|
|
|
+ let _ = local_server.handle_client_data(client_data, user_id).await?;
|
|
|
+ Ok::<(), FlowyError>(())
|
|
|
+ },
|
|
|
+ }
|
|
|
+ };
|
|
|
+ if let Err(e) = f().await {
|
|
|
+ tracing::error!("[LocalWebSocket] error: {:?}", e);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
impl FlowyRawWebSocket for LocalWebSocket {
|
|
|
fn initialize(&self) -> FutureResult<(), FlowyError> {
|
|
|
- let mut server_rx = self.server_rx.write().take().expect("Only take once");
|
|
|
+ let mut server_rx = self.local_server_rx.write().take().expect("Only take once");
|
|
|
let receivers = self.receivers.clone();
|
|
|
tokio::spawn(async move {
|
|
|
while let Some(message) = server_rx.recv().await {
|
|
|
match receivers.get(&message.module) {
|
|
|
None => tracing::error!("Can't find any handler for message: {:?}", message),
|
|
|
- Some(handler) => handler.receive_message(message.clone()),
|
|
|
+ Some(receiver) => receiver.receive_message(message.clone()),
|
|
|
}
|
|
|
}
|
|
|
});
|
|
@@ -97,7 +125,7 @@ impl FlowyRawWebSocket for LocalWebSocket {
|
|
|
|
|
|
fn start_connect(&self, addr: String, user_id: String) -> FutureResult<(), FlowyError> {
|
|
|
*self.user_id.write() = Some(user_id);
|
|
|
- self.spawn_client(addr);
|
|
|
+ self.spawn_client_ws_receiver(addr);
|
|
|
FutureResult::new(async { Ok(()) })
|
|
|
}
|
|
|
|