|
@@ -11,7 +11,6 @@ use dashmap::DashMap;
|
|
use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
|
|
use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
|
|
use futures_core::{ready, Stream};
|
|
use futures_core::{ready, Stream};
|
|
use lib_infra::retry::{Action, FixedInterval, Retry};
|
|
use lib_infra::retry::{Action, FixedInterval, Retry};
|
|
-use parking_lot::RwLock;
|
|
|
|
use pin_project::pin_project;
|
|
use pin_project::pin_project;
|
|
use std::{
|
|
use std::{
|
|
convert::TryFrom,
|
|
convert::TryFrom,
|
|
@@ -22,7 +21,7 @@ use std::{
|
|
task::{Context, Poll},
|
|
task::{Context, Poll},
|
|
time::Duration,
|
|
time::Duration,
|
|
};
|
|
};
|
|
-use tokio::sync::{broadcast, oneshot};
|
|
|
|
|
|
+use tokio::sync::{broadcast, oneshot, RwLock};
|
|
use tokio_tungstenite::tungstenite::{
|
|
use tokio_tungstenite::tungstenite::{
|
|
protocol::{frame::coding::CloseCode, CloseFrame},
|
|
protocol::{frame::coding::CloseCode, CloseFrame},
|
|
Message,
|
|
Message,
|
|
@@ -39,19 +38,22 @@ pub trait WSMessageReceiver: Sync + Send + 'static {
|
|
|
|
|
|
pub struct WSController {
|
|
pub struct WSController {
|
|
handlers: Handlers,
|
|
handlers: Handlers,
|
|
- state_notify: Arc<broadcast::Sender<WSConnectState>>,
|
|
|
|
- sender_ctrl: Arc<RwLock<WSSenderController>>,
|
|
|
|
addr: Arc<RwLock<Option<String>>>,
|
|
addr: Arc<RwLock<Option<String>>>,
|
|
|
|
+ sender: Arc<RwLock<Option<Arc<WSSender>>>>,
|
|
|
|
+ conn_state_notify: Arc<RwLock<WSConnectStateNotifier>>,
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+impl std::fmt::Display for WSController {
|
|
|
|
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_str("WebSocket") }
|
|
}
|
|
}
|
|
|
|
|
|
impl std::default::Default for WSController {
|
|
impl std::default::Default for WSController {
|
|
fn default() -> Self {
|
|
fn default() -> Self {
|
|
- let (state_notify, _) = broadcast::channel(16);
|
|
|
|
Self {
|
|
Self {
|
|
handlers: DashMap::new(),
|
|
handlers: DashMap::new(),
|
|
- sender_ctrl: Arc::new(RwLock::new(WSSenderController::default())),
|
|
|
|
- state_notify: Arc::new(state_notify),
|
|
|
|
addr: Arc::new(RwLock::new(None)),
|
|
addr: Arc::new(RwLock::new(None)),
|
|
|
|
+ sender: Arc::new(RwLock::new(None)),
|
|
|
|
+ conn_state_notify: Arc::new(RwLock::new(WSConnectStateNotifier::default())),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -62,36 +64,52 @@ impl WSController {
|
|
pub fn add_ws_message_receiver(&self, handler: Arc<dyn WSMessageReceiver>) -> Result<(), WSError> {
|
|
pub fn add_ws_message_receiver(&self, handler: Arc<dyn WSMessageReceiver>) -> Result<(), WSError> {
|
|
let source = handler.source();
|
|
let source = handler.source();
|
|
if self.handlers.contains_key(&source) {
|
|
if self.handlers.contains_key(&source) {
|
|
- log::error!("WsSource's {:?} is already registered", source);
|
|
|
|
|
|
+ log::error!("{:?} is already registered", source);
|
|
}
|
|
}
|
|
self.handlers.insert(source, handler);
|
|
self.handlers.insert(source, handler);
|
|
Ok(())
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
|
|
pub async fn start(&self, addr: String) -> Result<(), ServerError> {
|
|
pub async fn start(&self, addr: String) -> Result<(), ServerError> {
|
|
- *self.addr.write() = Some(addr.clone());
|
|
|
|
|
|
+ *self.addr.write().await = Some(addr.clone());
|
|
let strategy = FixedInterval::from_millis(5000).take(3);
|
|
let strategy = FixedInterval::from_millis(5000).take(3);
|
|
self.connect(addr, strategy).await
|
|
self.connect(addr, strategy).await
|
|
}
|
|
}
|
|
|
|
|
|
- pub async fn stop(&self) { self.sender_ctrl.write().set_state(WSConnectState::Disconnected); }
|
|
|
|
|
|
+ pub async fn stop(&self) {
|
|
|
|
+ if self.conn_state_notify.read().await.conn_state.is_connected() {
|
|
|
|
+ tracing::trace!("[{}] stop", self);
|
|
|
|
+ self.conn_state_notify
|
|
|
|
+ .write()
|
|
|
|
+ .await
|
|
|
|
+ .update_state(WSConnectState::Disconnected);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
|
|
async fn connect<T, I>(&self, addr: String, strategy: T) -> Result<(), ServerError>
|
|
async fn connect<T, I>(&self, addr: String, strategy: T) -> Result<(), ServerError>
|
|
where
|
|
where
|
|
T: IntoIterator<IntoIter = I, Item = Duration>,
|
|
T: IntoIterator<IntoIter = I, Item = Duration>,
|
|
I: Iterator<Item = Duration> + Send + 'static,
|
|
I: Iterator<Item = Duration> + Send + 'static,
|
|
{
|
|
{
|
|
|
|
+ let mut conn_state_notify = self.conn_state_notify.write().await;
|
|
|
|
+ let conn_state = conn_state_notify.conn_state.clone();
|
|
|
|
+ if conn_state.is_connected() || conn_state.is_connecting() {
|
|
|
|
+ return Ok(());
|
|
|
|
+ }
|
|
|
|
+
|
|
let (ret, rx) = oneshot::channel::<Result<(), ServerError>>();
|
|
let (ret, rx) = oneshot::channel::<Result<(), ServerError>>();
|
|
- *self.addr.write() = Some(addr.clone());
|
|
|
|
|
|
+ *self.addr.write().await = Some(addr.clone());
|
|
let action = WSConnectAction {
|
|
let action = WSConnectAction {
|
|
addr,
|
|
addr,
|
|
handlers: self.handlers.clone(),
|
|
handlers: self.handlers.clone(),
|
|
};
|
|
};
|
|
-
|
|
|
|
let retry = Retry::spawn(strategy, action);
|
|
let retry = Retry::spawn(strategy, action);
|
|
- let sender_ctrl = self.sender_ctrl.clone();
|
|
|
|
- sender_ctrl.write().set_state(WSConnectState::Connecting);
|
|
|
|
|
|
+ conn_state_notify.update_state(WSConnectState::Connecting);
|
|
|
|
+ drop(conn_state_notify);
|
|
|
|
|
|
|
|
+ let cloned_conn_state = self.conn_state_notify.clone();
|
|
|
|
+ let cloned_sender = self.sender.clone();
|
|
|
|
+ tracing::trace!("[{}] start connecting", self);
|
|
tokio::spawn(async move {
|
|
tokio::spawn(async move {
|
|
match retry.await {
|
|
match retry.await {
|
|
Ok(result) => {
|
|
Ok(result) => {
|
|
@@ -100,30 +118,36 @@ impl WSController {
|
|
handlers_fut,
|
|
handlers_fut,
|
|
sender,
|
|
sender,
|
|
} = result;
|
|
} = result;
|
|
- sender_ctrl.write().set_sender(sender);
|
|
|
|
- sender_ctrl.write().set_state(WSConnectState::Connected);
|
|
|
|
|
|
+
|
|
|
|
+ cloned_conn_state.write().await.update_state(WSConnectState::Connected);
|
|
|
|
+ *cloned_sender.write().await = Some(Arc::new(sender));
|
|
|
|
+
|
|
let _ = ret.send(Ok(()));
|
|
let _ = ret.send(Ok(()));
|
|
- spawn_stream_and_handlers(stream, handlers_fut, sender_ctrl.clone()).await;
|
|
|
|
|
|
+ spawn_stream_and_handlers(stream, handlers_fut).await;
|
|
},
|
|
},
|
|
Err(e) => {
|
|
Err(e) => {
|
|
- sender_ctrl.write().set_error(e.clone());
|
|
|
|
|
|
+ cloned_conn_state
|
|
|
|
+ .write()
|
|
|
|
+ .await
|
|
|
|
+ .update_state(WSConnectState::Disconnected);
|
|
let _ = ret.send(Err(ServerError::internal().context(e)));
|
|
let _ = ret.send(Err(ServerError::internal().context(e)));
|
|
},
|
|
},
|
|
}
|
|
}
|
|
});
|
|
});
|
|
-
|
|
|
|
rx.await?
|
|
rx.await?
|
|
}
|
|
}
|
|
|
|
|
|
pub async fn retry(&self, count: usize) -> Result<(), ServerError> {
|
|
pub async fn retry(&self, count: usize) -> Result<(), ServerError> {
|
|
- if !self.sender_ctrl.read().is_disconnected() {
|
|
|
|
|
|
+ if !self.conn_state_notify.read().await.conn_state.is_disconnected() {
|
|
return Ok(());
|
|
return Ok(());
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ tracing::trace!("[WebSocket]: retry connect...");
|
|
let strategy = FixedInterval::from_millis(5000).take(count);
|
|
let strategy = FixedInterval::from_millis(5000).take(count);
|
|
let addr = self
|
|
let addr = self
|
|
.addr
|
|
.addr
|
|
.read()
|
|
.read()
|
|
|
|
+ .await
|
|
.as_ref()
|
|
.as_ref()
|
|
.expect("Retry web socket connection failed, should call start_connect first")
|
|
.expect("Retry web socket connection failed, should call start_connect first")
|
|
.clone();
|
|
.clone();
|
|
@@ -131,25 +155,30 @@ impl WSController {
|
|
self.connect(addr, strategy).await
|
|
self.connect(addr, strategy).await
|
|
}
|
|
}
|
|
|
|
|
|
- pub fn subscribe_state(&self) -> broadcast::Receiver<WSConnectState> { self.state_notify.subscribe() }
|
|
|
|
|
|
+ pub async fn subscribe_state(&self) -> broadcast::Receiver<WSConnectState> {
|
|
|
|
+ self.conn_state_notify.read().await.notify.subscribe()
|
|
|
|
+ }
|
|
|
|
|
|
- pub fn ws_message_sender(&self) -> Result<Arc<WSSender>, WSError> {
|
|
|
|
- match self.sender_ctrl.read().sender() {
|
|
|
|
- None => Err(WSError::internal().context("WebSocket is not initialized, should call connect first")),
|
|
|
|
- Some(sender) => Ok(sender),
|
|
|
|
|
|
+ pub async fn ws_message_sender(&self) -> Result<Option<Arc<WSSender>>, WSError> {
|
|
|
|
+ let sender = self.sender.read().await.clone();
|
|
|
|
+ match sender {
|
|
|
|
+ None => match self.conn_state_notify.read().await.conn_state {
|
|
|
|
+ WSConnectState::Disconnected => {
|
|
|
|
+ let msg = "WebSocket is disconnected";
|
|
|
|
+ Err(WSError::internal().context(msg))
|
|
|
|
+ },
|
|
|
|
+ _ => Ok(None),
|
|
|
|
+ },
|
|
|
|
+ Some(sender) => Ok(Some(sender)),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
-async fn spawn_stream_and_handlers(
|
|
|
|
- stream: WSStream,
|
|
|
|
- handlers: WSHandlerFuture,
|
|
|
|
- sender_ctrl: Arc<RwLock<WSSenderController>>,
|
|
|
|
-) {
|
|
|
|
|
|
+async fn spawn_stream_and_handlers(stream: WSStream, handlers: WSHandlerFuture) {
|
|
tokio::select! {
|
|
tokio::select! {
|
|
result = stream => {
|
|
result = stream => {
|
|
if let Err(e) = result {
|
|
if let Err(e) = result {
|
|
- sender_ctrl.write().set_error(e);
|
|
|
|
|
|
+ tracing::error!("WSStream error: {:?}", e);
|
|
}
|
|
}
|
|
},
|
|
},
|
|
result = handlers => tracing::debug!("handlers completed {:?}", result),
|
|
result = handlers => tracing::debug!("handlers completed {:?}", result),
|
|
@@ -201,15 +230,13 @@ impl Future for WSHandlerFuture {
|
|
}
|
|
}
|
|
|
|
|
|
#[derive(Debug, Clone)]
|
|
#[derive(Debug, Clone)]
|
|
-pub struct WSSender {
|
|
|
|
- ws_tx: MsgSender,
|
|
|
|
-}
|
|
|
|
|
|
+pub struct WSSender(MsgSender);
|
|
|
|
|
|
impl WSSender {
|
|
impl WSSender {
|
|
pub fn send_msg<T: Into<WebSocketRawMessage>>(&self, msg: T) -> Result<(), WSError> {
|
|
pub fn send_msg<T: Into<WebSocketRawMessage>>(&self, msg: T) -> Result<(), WSError> {
|
|
let msg = msg.into();
|
|
let msg = msg.into();
|
|
let _ = self
|
|
let _ = self
|
|
- .ws_tx
|
|
|
|
|
|
+ .0
|
|
.unbounded_send(msg.into())
|
|
.unbounded_send(msg.into())
|
|
.map_err(|e| WSError::internal().context(e))?;
|
|
.map_err(|e| WSError::internal().context(e))?;
|
|
Ok(())
|
|
Ok(())
|
|
@@ -237,10 +264,7 @@ impl WSSender {
|
|
reason: reason.to_owned().into(),
|
|
reason: reason.to_owned().into(),
|
|
};
|
|
};
|
|
let msg = Message::Close(Some(frame));
|
|
let msg = Message::Close(Some(frame));
|
|
- let _ = self
|
|
|
|
- .ws_tx
|
|
|
|
- .unbounded_send(msg)
|
|
|
|
- .map_err(|e| WSError::internal().context(e))?;
|
|
|
|
|
|
+ let _ = self.0.unbounded_send(msg).map_err(|e| WSError::internal().context(e))?;
|
|
Ok(())
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -291,7 +315,7 @@ impl WSConnectActionFut {
|
|
// └───────────────┘ └──────────────┘
|
|
// └───────────────┘ └──────────────┘
|
|
let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
|
|
let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
|
|
let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
|
|
let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
|
|
- let sender = WSSender { ws_tx };
|
|
|
|
|
|
+ let sender = WSSender(ws_tx);
|
|
let handlers_fut = WSHandlerFuture::new(handlers, msg_rx);
|
|
let handlers_fut = WSHandlerFuture::new(handlers, msg_rx);
|
|
let conn = WSConnectionFuture::new(msg_tx, ws_rx, addr.clone());
|
|
let conn = WSConnectionFuture::new(msg_tx, ws_rx, addr.clone());
|
|
Self {
|
|
Self {
|
|
@@ -330,12 +354,20 @@ pub enum WSConnectState {
|
|
Disconnected,
|
|
Disconnected,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+impl WSConnectState {
|
|
|
|
+ fn is_connected(&self) -> bool { self == &WSConnectState::Connected }
|
|
|
|
+
|
|
|
|
+ fn is_connecting(&self) -> bool { self == &WSConnectState::Connecting }
|
|
|
|
+
|
|
|
|
+ fn is_disconnected(&self) -> bool { self == &WSConnectState::Disconnected || self == &WSConnectState::Init }
|
|
|
|
+}
|
|
|
|
+
|
|
impl std::fmt::Display for WSConnectState {
|
|
impl std::fmt::Display for WSConnectState {
|
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
match self {
|
|
match self {
|
|
WSConnectState::Init => f.write_str("Init"),
|
|
WSConnectState::Init => f.write_str("Init"),
|
|
- WSConnectState::Connected => f.write_str("Connecting"),
|
|
|
|
- WSConnectState::Connecting => f.write_str("Connected"),
|
|
|
|
|
|
+ WSConnectState::Connected => f.write_str("Connected"),
|
|
|
|
+ WSConnectState::Connecting => f.write_str("Connecting"),
|
|
WSConnectState::Disconnected => f.write_str("Disconnected"),
|
|
WSConnectState::Disconnected => f.write_str("Disconnected"),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -345,44 +377,32 @@ impl std::fmt::Debug for WSConnectState {
|
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_str(&format!("{}", self)) }
|
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_str(&format!("{}", self)) }
|
|
}
|
|
}
|
|
|
|
|
|
-struct WSSenderController {
|
|
|
|
- state: WSConnectState,
|
|
|
|
- state_notify: Arc<broadcast::Sender<WSConnectState>>,
|
|
|
|
- sender: Option<Arc<WSSender>>,
|
|
|
|
|
|
+struct WSConnectStateNotifier {
|
|
|
|
+ conn_state: WSConnectState,
|
|
|
|
+ notify: Arc<broadcast::Sender<WSConnectState>>,
|
|
}
|
|
}
|
|
|
|
|
|
-impl WSSenderController {
|
|
|
|
- fn set_sender(&mut self, sender: WSSender) { self.sender = Some(Arc::new(sender)); }
|
|
|
|
-
|
|
|
|
- fn set_state(&mut self, state: WSConnectState) {
|
|
|
|
- if state != WSConnectState::Connected {
|
|
|
|
- self.sender = None;
|
|
|
|
|
|
+impl std::default::Default for WSConnectStateNotifier {
|
|
|
|
+ fn default() -> Self {
|
|
|
|
+ let (state_notify, _) = broadcast::channel(16);
|
|
|
|
+ Self {
|
|
|
|
+ conn_state: WSConnectState::Init,
|
|
|
|
+ notify: Arc::new(state_notify),
|
|
}
|
|
}
|
|
-
|
|
|
|
- self.state = state;
|
|
|
|
- let _ = self.state_notify.send(self.state.clone());
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- fn set_error(&mut self, error: WSError) {
|
|
|
|
- log::error!("{:?}", error);
|
|
|
|
- self.set_state(WSConnectState::Disconnected);
|
|
|
|
}
|
|
}
|
|
-
|
|
|
|
- fn sender(&self) -> Option<Arc<WSSender>> { self.sender.clone() }
|
|
|
|
-
|
|
|
|
- #[allow(dead_code)]
|
|
|
|
- fn is_connecting(&self) -> bool { self.state == WSConnectState::Connecting }
|
|
|
|
-
|
|
|
|
- fn is_disconnected(&self) -> bool { self.state == WSConnectState::Disconnected }
|
|
|
|
}
|
|
}
|
|
|
|
|
|
-impl std::default::Default for WSSenderController {
|
|
|
|
- fn default() -> Self {
|
|
|
|
- let (state_notify, _) = broadcast::channel(16);
|
|
|
|
- WSSenderController {
|
|
|
|
- state: WSConnectState::Init,
|
|
|
|
- state_notify: Arc::new(state_notify),
|
|
|
|
- sender: None,
|
|
|
|
|
|
+impl WSConnectStateNotifier {
|
|
|
|
+ fn update_state(&mut self, new_state: WSConnectState) {
|
|
|
|
+ if self.conn_state == new_state {
|
|
|
|
+ return;
|
|
}
|
|
}
|
|
|
|
+ tracing::debug!(
|
|
|
|
+ "WebSocket connect state did change: {} -> {}",
|
|
|
|
+ self.conn_state,
|
|
|
|
+ new_state
|
|
|
|
+ );
|
|
|
|
+ self.conn_state = new_state.clone();
|
|
|
|
+ let _ = self.notify.send(new_state);
|
|
}
|
|
}
|
|
}
|
|
}
|