|
@@ -1,13 +1,15 @@
|
|
|
use crate::{
|
|
|
- connect::{Retry, WsConnectionFuture},
|
|
|
+ connect::{WsConnectionFuture, WsStream},
|
|
|
errors::WsError,
|
|
|
WsMessage,
|
|
|
WsModule,
|
|
|
};
|
|
|
use bytes::Bytes;
|
|
|
-use flowy_net::errors::ServerError;
|
|
|
+use dashmap::DashMap;
|
|
|
+use flowy_net::errors::{internal_error, ServerError};
|
|
|
use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
|
|
|
use futures_core::{future::BoxFuture, ready, Stream};
|
|
|
+use parking_lot::RwLock;
|
|
|
use pin_project::pin_project;
|
|
|
use std::{
|
|
|
collections::HashMap,
|
|
@@ -17,7 +19,10 @@ use std::{
|
|
|
sync::Arc,
|
|
|
task::{Context, Poll},
|
|
|
};
|
|
|
-use tokio::{sync::RwLock, task::JoinHandle};
|
|
|
+use tokio::{
|
|
|
+ sync::{broadcast, oneshot},
|
|
|
+ task::JoinHandle,
|
|
|
+};
|
|
|
use tokio_tungstenite::tungstenite::{
|
|
|
protocol::{frame::coding::CloseCode, CloseFrame},
|
|
|
Message,
|
|
@@ -25,6 +30,8 @@ use tokio_tungstenite::tungstenite::{
|
|
|
|
|
|
pub type MsgReceiver = UnboundedReceiver<Message>;
|
|
|
pub type MsgSender = UnboundedSender<Message>;
|
|
|
+type Handlers = DashMap<WsModule, Arc<dyn WsMessageHandler>>;
|
|
|
+
|
|
|
pub trait WsMessageHandler: Sync + Send + 'static {
|
|
|
fn source(&self) -> WsModule;
|
|
|
fn receive_message(&self, msg: WsMessage);
|
|
@@ -46,6 +53,7 @@ impl WsStateNotify {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+#[derive(Clone)]
|
|
|
pub enum WsState {
|
|
|
Init,
|
|
|
Connected(Arc<WsSender>),
|
|
@@ -53,37 +61,23 @@ pub enum WsState {
|
|
|
}
|
|
|
|
|
|
pub struct WsController {
|
|
|
- handlers: HashMap<WsModule, Arc<dyn WsMessageHandler>>,
|
|
|
- state_notify: Arc<RwLock<WsStateNotify>>,
|
|
|
- #[allow(dead_code)]
|
|
|
- addr: Option<String>,
|
|
|
- sender: Option<Arc<WsSender>>,
|
|
|
+ handlers: Handlers,
|
|
|
+ state_notify: Arc<broadcast::Sender<WsState>>,
|
|
|
+ sender: RwLock<Option<Arc<WsSender>>>,
|
|
|
}
|
|
|
|
|
|
impl WsController {
|
|
|
pub fn new() -> Self {
|
|
|
- let state_notify = Arc::new(RwLock::new(WsStateNotify {
|
|
|
- state: WsState::Init,
|
|
|
- callback: None,
|
|
|
- }));
|
|
|
-
|
|
|
+ let (state_notify, _) = broadcast::channel(16);
|
|
|
let controller = Self {
|
|
|
- handlers: HashMap::new(),
|
|
|
- state_notify,
|
|
|
- addr: None,
|
|
|
- sender: None,
|
|
|
+ handlers: DashMap::new(),
|
|
|
+ sender: RwLock::new(None),
|
|
|
+ state_notify: Arc::new(state_notify),
|
|
|
};
|
|
|
controller
|
|
|
}
|
|
|
|
|
|
- pub async fn state_callback<SC>(&self, callback: SC)
|
|
|
- where
|
|
|
- SC: Fn(&WsState) + Send + Sync + 'static,
|
|
|
- {
|
|
|
- (self.state_notify.write().await).callback = Some(Arc::new(callback));
|
|
|
- }
|
|
|
-
|
|
|
- pub fn add_handler(&mut self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
|
|
|
+ pub fn add_handler(&self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
|
|
|
let source = handler.source();
|
|
|
if self.handlers.contains_key(&source) {
|
|
|
log::error!("WsSource's {:?} is already registered", source);
|
|
@@ -92,60 +86,47 @@ impl WsController {
|
|
|
Ok(())
|
|
|
}
|
|
|
|
|
|
- pub fn connect(&mut self, addr: String) -> Result<JoinHandle<()>, ServerError> { self._connect(addr.clone(), None) }
|
|
|
-
|
|
|
- pub fn connect_with_retry<F>(&mut self, addr: String, retry: Retry<F>) -> Result<JoinHandle<()>, ServerError>
|
|
|
- where
|
|
|
- F: Fn(&str) + Send + Sync + 'static,
|
|
|
- {
|
|
|
- self._connect(addr, Some(Box::pin(async { retry.await })))
|
|
|
+ pub async fn connect(&self, addr: String) -> Result<(), ServerError> {
|
|
|
+ let (ret, rx) = oneshot::channel::<Result<(), ServerError>>();
|
|
|
+ self._connect(addr.clone(), ret);
|
|
|
+ rx.await?
|
|
|
}
|
|
|
|
|
|
- pub fn get_sender(&self) -> Result<Arc<WsSender>, WsError> {
|
|
|
- match &self.sender {
|
|
|
+ #[allow(dead_code)]
|
|
|
+ pub fn state_subscribe(&self) -> broadcast::Receiver<WsState> { self.state_notify.subscribe() }
|
|
|
+
|
|
|
+ pub fn sender(&self) -> Result<Arc<WsSender>, WsError> {
|
|
|
+ match &*self.sender.read() {
|
|
|
None => Err(WsError::internal().context("WsSender is not initialized, should call connect first")),
|
|
|
Some(sender) => Ok(sender.clone()),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- fn _connect(&mut self, addr: String, retry: Option<BoxFuture<'static, ()>>) -> Result<JoinHandle<()>, ServerError> {
|
|
|
+ fn _connect(&self, addr: String, ret: oneshot::Sender<Result<(), ServerError>>) {
|
|
|
log::debug!("🐴 ws connect: {}", &addr);
|
|
|
let (connection, handlers) = self.make_connect(addr.clone());
|
|
|
let state_notify = self.state_notify.clone();
|
|
|
let sender = self
|
|
|
.sender
|
|
|
+ .read()
|
|
|
.clone()
|
|
|
.expect("Sender should be not empty after calling make_connect");
|
|
|
- Ok(tokio::spawn(async move {
|
|
|
+ tokio::spawn(async move {
|
|
|
match connection.await {
|
|
|
Ok(stream) => {
|
|
|
- state_notify.write().await.update_state(WsState::Connected(sender));
|
|
|
- tokio::select! {
|
|
|
- result = stream => {
|
|
|
- match result {
|
|
|
- Ok(_) => {},
|
|
|
- Err(e) => {
|
|
|
- // TODO: retry?
|
|
|
- log::error!("ws stream error {:?}", e);
|
|
|
- state_notify.write().await.update_state(WsState::Disconnected(e));
|
|
|
- }
|
|
|
- }
|
|
|
- },
|
|
|
- result = handlers => log::debug!("handlers completed {:?}", result),
|
|
|
- };
|
|
|
+ state_notify.send(WsState::Connected(sender));
|
|
|
+ ret.send(Ok(()));
|
|
|
+ spawn_steam_and_handlers(stream, handlers, state_notify).await;
|
|
|
},
|
|
|
Err(e) => {
|
|
|
- log::error!("ws connect {} failed {:?}", addr, e);
|
|
|
- state_notify.write().await.update_state(WsState::Disconnected(e));
|
|
|
- if let Some(retry) = retry {
|
|
|
- tokio::spawn(retry);
|
|
|
- }
|
|
|
+ state_notify.send(WsState::Disconnected(e.clone()));
|
|
|
+ ret.send(Err(ServerError::internal().context(e)));
|
|
|
},
|
|
|
}
|
|
|
- }))
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
- fn make_connect(&mut self, addr: String) -> (WsConnectionFuture, WsHandlerFuture) {
|
|
|
+ fn make_connect(&self, addr: String) -> (WsConnectionFuture, WsHandlerFuture) {
|
|
|
// Stream User
|
|
|
// ┌───────────────┐ ┌──────────────┐
|
|
|
// ┌──────┐ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
|
|
@@ -159,8 +140,7 @@ impl WsController {
|
|
|
let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
|
|
|
let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
|
|
|
let handlers = self.handlers.clone();
|
|
|
- self.sender = Some(Arc::new(WsSender { ws_tx }));
|
|
|
- self.addr = Some(addr.clone());
|
|
|
+ *self.sender.write() = Some(Arc::new(WsSender { ws_tx }));
|
|
|
(
|
|
|
WsConnectionFuture::new(msg_tx, ws_rx, addr),
|
|
|
WsHandlerFuture::new(handlers, msg_rx),
|
|
@@ -168,17 +148,36 @@ impl WsController {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+async fn spawn_steam_and_handlers(
|
|
|
+ stream: WsStream,
|
|
|
+ handlers: WsHandlerFuture,
|
|
|
+ state_notify: Arc<broadcast::Sender<WsState>>,
|
|
|
+) {
|
|
|
+ tokio::select! {
|
|
|
+ result = stream => {
|
|
|
+ match result {
|
|
|
+ Ok(_) => {},
|
|
|
+ Err(e) => {
|
|
|
+ // TODO: retry?
|
|
|
+ log::error!("ws stream error {:?}", e);
|
|
|
+ state_notify.send(WsState::Disconnected(e));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ },
|
|
|
+ result = handlers => log::debug!("handlers completed {:?}", result),
|
|
|
+ };
|
|
|
+}
|
|
|
+
|
|
|
#[pin_project]
|
|
|
pub struct WsHandlerFuture {
|
|
|
#[pin]
|
|
|
msg_rx: MsgReceiver,
|
|
|
- handlers: HashMap<WsModule, Arc<dyn WsMessageHandler>>,
|
|
|
+ // Opti: Hashmap would be better
|
|
|
+ handlers: Handlers,
|
|
|
}
|
|
|
|
|
|
impl WsHandlerFuture {
|
|
|
- fn new(handlers: HashMap<WsModule, Arc<dyn WsMessageHandler>>, msg_rx: MsgReceiver) -> Self {
|
|
|
- Self { msg_rx, handlers }
|
|
|
- }
|
|
|
+ fn new(handlers: Handlers, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } }
|
|
|
|
|
|
fn handler_ws_message(&self, message: Message) {
|
|
|
match message {
|