123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
use std::{
    collections::{hash_map::Entry, HashMap},
    future::Future,
    marker::PhantomData,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use flowy_net::errors::ServerError;
use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
use futures_core::{future::BoxFuture, ready, Stream};
use pin_project::pin_project;
use tokio::task::JoinHandle;
use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};

use crate::{
    connect::{Retry, WsConnection},
    errors::WsError,
    WsMessage,
};
- pub type MsgReceiver = UnboundedReceiver<Message>;
- pub type MsgSender = UnboundedSender<Message>;
- pub trait WsMessageHandler: Sync + Send + 'static {
- fn source(&self) -> String;
- fn receive_message(&self, msg: WsMessage);
- }
- pub struct WsController {
- sender: Option<Arc<WsSender>>,
- handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
- addr: Option<String>,
- }
- impl WsController {
- pub fn new() -> Self {
- let controller = Self {
- sender: None,
- handlers: HashMap::new(),
- addr: None,
- };
- controller
- }
- pub fn add_handler(&mut self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
- let source = handler.source();
- if self.handlers.contains_key(&source) {
- return Err(WsError::duplicate_source());
- }
- self.handlers.insert(source, handler);
- Ok(())
- }
- pub fn connect(&mut self, addr: String) -> Result<JoinHandle<()>, ServerError> { self._connect(addr.clone(), None) }
- pub fn connect_with_retry<F>(&mut self, addr: String, retry: Retry<F>) -> Result<JoinHandle<()>, ServerError>
- where
- F: Fn(&str) + Send + Sync + 'static,
- {
- self._connect(addr, Some(Box::pin(async { retry.await })))
- }
- fn _connect(&mut self, addr: String, retry: Option<BoxFuture<'static, ()>>) -> Result<JoinHandle<()>, ServerError> {
- log::debug!("๐ด ws connect: {}", &addr);
- let (connection, handlers) = self.make_connect(addr.clone());
- Ok(tokio::spawn(async move {
- match connection.await {
- Ok(stream) => {
- tokio::select! {
- result = stream => {
- match result {
- Ok(_) => {},
- Err(e) => {
- // TODO: retry?
- log::error!("ws stream error {:?}", e);
- }
- }
- },
- result = handlers => log::debug!("handlers completed {:?}", result),
- };
- },
- Err(e) => match retry {
- None => log::error!("ws connect {} failed {:?}", addr, e),
- Some(retry) => {
- tokio::spawn(retry);
- },
- },
- }
- }))
- }
- fn make_connect(&mut self, addr: String) -> (WsConnection, WsHandlers) {
- // Stream User
- // โโโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโ
- // โโโโโโโโ โ โโโโโโโโโโโ โ โโโโโโโโโโ โ โโโโโโโโโโ โ
- // โServerโโโโโโโโผโโถโ ws_read โโโโผโโโโถโ msg_tx โโโโโผโโถโ msg_rx โ โ
- // โโโโโโโโ โ โโโโโโโโโโโ โ โโโโโโโโโโ โ โโโโโโโโโโ โ
- // โฒ โ โ โ โ
- // โ โ โโโโโโโโโโโ โ โโโโโโโโโโ โ โโโโโโโโโโ โ
- // โโโโโโโโโโโผโโโws_write โโโโผโโโโโ ws_rx โโโโโผโโโ ws_tx โ โ
- // โ โโโโโโโโโโโ โ โโโโโโโโโโ โ โโโโโโโโโโ โ
- // โโโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโ
- let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
- let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
- let handlers = self.handlers.clone();
- self.sender = Some(Arc::new(WsSender::new(ws_tx)));
- self.addr = Some(addr.clone());
- (WsConnection::new(msg_tx, ws_rx, addr), WsHandlers::new(handlers, msg_rx))
- }
- pub fn send_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), WsError> {
- match self.sender.as_ref() {
- None => Err(WsError::internal().context("Should call make_connect first")),
- Some(sender) => sender.send(msg.into()),
- }
- }
- }
- #[pin_project]
- pub struct WsHandlers {
- #[pin]
- msg_rx: MsgReceiver,
- handlers: HashMap<String, Arc<dyn WsMessageHandler>>,
- }
- impl WsHandlers {
- fn new(handlers: HashMap<String, Arc<dyn WsMessageHandler>>, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } }
- }
- impl Future for WsHandlers {
- type Output = ();
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
- loop {
- match ready!(self.as_mut().project().msg_rx.poll_next(cx)) {
- None => {
- return Poll::Ready(());
- },
- Some(message) => {
- log::debug!("๐ด ws handler receive message");
- let message = WsMessage::from(message);
- match self.handlers.get(&message.source) {
- None => log::error!("Can't find any handler for message: {:?}", message),
- Some(handler) => handler.receive_message(message.clone()),
- }
- },
- }
- }
- }
- }
- struct WsSender {
- ws_tx: MsgSender,
- }
- impl WsSender {
- pub fn new(ws_tx: MsgSender) -> Self { Self { ws_tx } }
- pub fn send(&self, msg: WsMessage) -> Result<(), WsError> {
- let _ = self.ws_tx.unbounded_send(msg.into()).map_err(|e| WsError::internal().context(e))?;
- Ok(())
- }
- }
#[cfg(test)]
mod tests {
    use super::WsController;

    /// Smoke test: sets up a connection to `flowy_net::config::WS_ADDR` and runs
    /// until either the connection future or the handler future finishes.
    ///
    /// NOTE(review): this opens a real network connection to that address.
    #[tokio::test]
    async fn connect() {
        std::env::set_var("RUST_LOG", "Debug");
        // `try_init` rather than `init`: `init` panics if another test in the same
        // binary has already installed the global logger.
        let _ = env_logger::try_init();

        let mut controller = WsController::new();
        let addr = format!("{}/123", flowy_net::config::WS_ADDR.as_str());
        let (connection, handlers) = controller.make_connect(addr);
        tokio::select! {
            r = connection => println!("write completed {:?}", r),
            _ = handlers => println!("read completed"),
        };
    }
}
|