Przeglądaj źródła

retry ws connection using WsConnectAction

appflowy 3 lat temu
rodzic
commit
d70072ae9f

+ 8 - 8
rust-lib/flowy-infra/src/retry/future.rs

@@ -164,7 +164,7 @@ where
 }
 
 /// An action can be run multiple times and produces a future.
-pub trait Action {
+pub trait Action: Send + Sync {
     type Future: Future<Output = Result<Self::Item, Self::Error>>;
     type Item;
     type Error;
@@ -172,13 +172,13 @@ pub trait Action {
     fn run(&mut self) -> Self::Future;
 }
 
-impl<R, E, T: Future<Output = Result<R, E>>, F: FnMut() -> T> Action for F {
-    type Future = T;
-    type Item = R;
-    type Error = E;
-
-    fn run(&mut self) -> Self::Future { self() }
-}
+// impl<R, E, T: Future<Output = Result<R, E>>, F: FnMut() -> T> Action for F {
+//     type Future = T;
+//     type Item = R;
+//     type Error = E;
+//
+//     fn run(&mut self) -> Self::Future { self() }
+// }
 
 pub trait Condition<E> {
     fn should_retry(&mut self, error: &E) -> bool;

+ 1 - 1
rust-lib/flowy-user/src/services/user/user_session.rs

@@ -16,7 +16,7 @@ use flowy_database::{
     ExpressionMethods,
     UserDatabaseConnection,
 };
-use flowy_infra::kv::KV;
+use flowy_infra::{future::wrap_future, kv::KV};
 use flowy_net::config::ServerConfig;
 use flowy_sqlite::ConnectionPool;
 use flowy_ws::{WsController, WsMessage, WsMessageHandler};

+ 3 - 1
rust-lib/flowy-ws/src/connect.rs

@@ -21,7 +21,9 @@ pub struct WsConnectionFuture {
     msg_tx: Option<MsgSender>,
     ws_rx: Option<MsgReceiver>,
     #[pin]
-    fut: BoxFuture<'static, Result<(WebSocketStream<MaybeTlsStream<TcpStream>>, Response), Error>>,
+    fut: Pin<
+        Box<dyn Future<Output = Result<(WebSocketStream<MaybeTlsStream<TcpStream>>, Response), Error>> + Send + Sync>,
+    >,
 }
 
 impl WsConnectionFuture {

+ 114 - 63
rust-lib/flowy-ws/src/ws.rs

@@ -6,7 +6,12 @@ use crate::{
 };
 use bytes::Bytes;
 use dashmap::DashMap;
+use flowy_infra::{
+    future::{wrap_future, FnFuture},
+    retry::{Action, ExponentialBackoff, Retry},
+};
 use flowy_net::errors::ServerError;
+use futures::future::BoxFuture;
 use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
 use futures_core::{ready, Stream};
 use parking_lot::RwLock;
@@ -43,7 +48,7 @@ pub enum WsState {
 pub struct WsController {
     handlers: Handlers,
     state_notify: Arc<broadcast::Sender<WsState>>,
-    sender: RwLock<Option<Arc<WsSender>>>,
+    sender: Arc<RwLock<Option<Arc<WsSender>>>>,
 }
 
 impl WsController {
@@ -51,7 +56,7 @@ impl WsController {
         let (state_notify, _) = broadcast::channel(16);
         let controller = Self {
             handlers: DashMap::new(),
-            sender: RwLock::new(None),
+            sender: Arc::new(RwLock::new(None)),
             state_notify: Arc::new(state_notify),
         };
         controller
@@ -68,63 +73,50 @@ impl WsController {
 
     pub async fn connect(&self, addr: String) -> Result<(), ServerError> {
         let (ret, rx) = oneshot::channel::<Result<(), ServerError>>();
-        self._connect(addr.clone(), ret);
-        rx.await?
-    }
-
-    #[allow(dead_code)]
-    pub fn state_subscribe(&self) -> broadcast::Receiver<WsState> { self.state_notify.subscribe() }
 
-    pub fn sender(&self) -> Result<Arc<WsSender>, WsError> {
-        match &*self.sender.read() {
-            None => Err(WsError::internal().context("WsSender is not initialized, should call connect first")),
-            Some(sender) => Ok(sender.clone()),
-        }
-    }
-
-    fn _connect(&self, addr: String, ret: oneshot::Sender<Result<(), ServerError>>) {
-        log::debug!("🐴 ws connect: {}", &addr);
-        let (connection, handlers) = self.make_connect(addr.clone());
+        let action = WsConnectAction {
+            addr,
+            handlers: self.handlers.clone(),
+        };
+        let strategy = ExponentialBackoff::from_millis(100).take(3);
+        let retry = Retry::spawn(strategy, action);
+        let sender_holder = self.sender.clone();
         let state_notify = self.state_notify.clone();
-        let sender = self
-            .sender
-            .read()
-            .clone()
-            .expect("Sender should be not empty after calling make_connect");
+
         tokio::spawn(async move {
-            match connection.await {
-                Ok(stream) => {
+            match retry.await {
+                Ok(result) => {
+                    let WsConnectResult {
+                        stream,
+                        handlers_fut,
+                        sender,
+                    } = result;
+                    let sender = Arc::new(sender);
+                    *sender_holder.write() = Some(sender.clone());
+
                     let _ = state_notify.send(WsState::Connected(sender));
                     let _ = ret.send(Ok(()));
-                    spawn_stream_and_handlers(stream, handlers, state_notify).await;
+                    spawn_stream_and_handlers(stream, handlers_fut, state_notify).await;
                 },
                 Err(e) => {
+                    //
                     let _ = state_notify.send(WsState::Disconnected(e.clone()));
                     let _ = ret.send(Err(ServerError::internal().context(e)));
                 },
             }
         });
+
+        rx.await?
     }
 
-    fn make_connect(&self, addr: String) -> (WsConnectionFuture, WsHandlerFuture) {
-        //                Stream                             User
-        //               ┌───────────────┐                 ┌──────────────┐
-        // ┌──────┐      │  ┌─────────┐  │    ┌────────┐   │  ┌────────┐  │
-        // │Server│──────┼─▶│ ws_read │──┼───▶│ msg_tx │───┼─▶│ msg_rx │  │
-        // └──────┘      │  └─────────┘  │    └────────┘   │  └────────┘  │
-        //     ▲         │               │                 │              │
-        //     │         │  ┌─────────┐  │    ┌────────┐   │  ┌────────┐  │
-        //     └─────────┼──│ws_write │◀─┼────│ ws_rx  │◀──┼──│ ws_tx  │  │
-        //               │  └─────────┘  │    └────────┘   │  └────────┘  │
-        //               └───────────────┘                 └──────────────┘
-        let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
-        let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
-        let handlers = self.handlers.clone();
-        *self.sender.write() = Some(Arc::new(WsSender { ws_tx }));
-        (
-            WsConnectionFuture::new(msg_tx, ws_rx, addr),
-            WsHandlerFuture::new(handlers, msg_rx),
-        )
+    #[allow(dead_code)]
+    pub fn state_subscribe(&self) -> broadcast::Receiver<WsState> { self.state_notify.subscribe() }
+
+    pub fn sender(&self) -> Result<Arc<WsSender>, WsError> {
+        match &*self.sender.read() {
+            None => Err(WsError::internal().context("WsSender is not initialized, should call connect first")),
+            Some(sender) => Ok(sender.clone()),
+        }
     }
 }
 
@@ -239,21 +231,80 @@ impl WsSender {
     }
 }
 
-// #[cfg(test)]
-// mod tests {
-//     use super::WsController;
-//
-//     #[tokio::test]
-//     async fn connect() {
-//         std::env::set_var("RUST_LOG", "Debug");
-//         env_logger::init();
-//
-//         let mut controller = WsController::new();
-//         let addr = format!("{}/123", flowy_net::config::WS_ADDR.as_str());
-//         let (a, b) = controller.make_connect(addr);
-//         tokio::select! {
-//             r = a => println!("write completed {:?}", r),
-//             _ = b => println!("read completed"),
-//         };
-//     }
-// }
+struct WsConnectAction {
+    addr: String,
+    handlers: Handlers,
+}
+
+struct WsConnectResult {
+    stream: WsStream,
+    handlers_fut: WsHandlerFuture,
+    sender: WsSender,
+}
+
+#[pin_project]
+struct WsConnectActionFut {
+    addr: String,
+    #[pin]
+    conn: WsConnectionFuture,
+    handlers_fut: Option<WsHandlerFuture>,
+    sender: Option<WsSender>,
+}
+
+impl WsConnectActionFut {
+    fn new(addr: String, handlers: Handlers) -> Self {
+        //                Stream                             User
+        //               ┌───────────────┐                 ┌──────────────┐
+        // ┌──────┐      │  ┌─────────┐  │    ┌────────┐   │  ┌────────┐  │
+        // │Server│──────┼─▶│ ws_read │──┼───▶│ msg_tx │───┼─▶│ msg_rx │  │
+        // └──────┘      │  └─────────┘  │    └────────┘   │  └────────┘  │
+        //     ▲         │               │                 │              │
+        //     │         │  ┌─────────┐  │    ┌────────┐   │  ┌────────┐  │
+        //     └─────────┼──│ws_write │◀─┼────│ ws_rx  │◀──┼──│ ws_tx  │  │
+        //               │  └─────────┘  │    └────────┘   │  └────────┘  │
+        //               └───────────────┘                 └──────────────┘
+        log::debug!("🐴 ws start connect: {}", &addr);
+        let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
+        let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
+        let sender = WsSender { ws_tx };
+        let handlers_fut = WsHandlerFuture::new(handlers, msg_rx);
+        let conn = WsConnectionFuture::new(msg_tx, ws_rx, addr.clone());
+        Self {
+            addr,
+            conn,
+            handlers_fut: Some(handlers_fut),
+            sender: Some(sender),
+        }
+    }
+}
+
+impl Future for WsConnectActionFut {
+    type Output = Result<WsConnectResult, WsError>;
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let mut this = self.project();
+        match ready!(this.conn.as_mut().poll(cx)) {
+            Ok(stream) => {
+                let handlers_fut = this.handlers_fut.take().expect("Only take once");
+                let sender = this.sender.take().expect("Only take once");
+                Poll::Ready(Ok(WsConnectResult {
+                    stream,
+                    handlers_fut,
+                    sender,
+                }))
+            },
+            Err(e) => Poll::Ready(Err(WsError::internal().context(e))),
+        }
+    }
+}
+
+impl Action for WsConnectAction {
+    type Future = Pin<Box<dyn Future<Output = Result<Self::Item, Self::Error>> + Send + Sync>>;
+    type Item = WsConnectResult;
+    type Error = WsError;
+
+    fn run(&mut self) -> Self::Future {
+        let addr = self.addr.clone();
+        let handlers = self.handlers.clone();
+        Box::pin(WsConnectActionFut::new(addr, handlers))
+    }
+}