ws.rs 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. use crate::ws::connection::{FlowyRawWebSocket, FlowyWebSocket};
  2. use dashmap::DashMap;
  3. use flowy_error::FlowyError;
  4. use futures_util::future::BoxFuture;
  5. use lib_infra::future::FutureResult;
  6. use lib_ws::{WSChannel, WSConnectState, WSMessageReceiver, WebSocketRawMessage};
  7. use parking_lot::RwLock;
  8. use std::sync::Arc;
  9. use tokio::sync::{broadcast, broadcast::Receiver, mpsc::UnboundedReceiver};
  10. pub struct LocalWebSocket {
  11. user_id: Arc<RwLock<Option<String>>>,
  12. receivers: Arc<DashMap<WSChannel, Arc<dyn WSMessageReceiver>>>,
  13. state_sender: broadcast::Sender<WSConnectState>,
  14. server_ws_receiver: RwLock<Option<UnboundedReceiver<WebSocketRawMessage>>>,
  15. server_ws_sender: broadcast::Sender<WebSocketRawMessage>,
  16. }
  17. impl LocalWebSocket {
  18. pub fn new(
  19. server_ws_receiver: UnboundedReceiver<WebSocketRawMessage>,
  20. server_ws_sender: broadcast::Sender<WebSocketRawMessage>,
  21. ) -> Self {
  22. let user_id = Arc::new(RwLock::new(None));
  23. let receivers = Arc::new(DashMap::new());
  24. let server_ws_receiver = RwLock::new(Some(server_ws_receiver));
  25. let (state_sender, _) = broadcast::channel(16);
  26. LocalWebSocket {
  27. user_id,
  28. receivers,
  29. state_sender,
  30. server_ws_receiver,
  31. server_ws_sender,
  32. }
  33. }
  34. }
  35. impl FlowyRawWebSocket for LocalWebSocket {
  36. fn initialize(&self) -> FutureResult<(), FlowyError> {
  37. let mut server_ws_receiver = self.server_ws_receiver.write().take().expect("Only take once");
  38. let receivers = self.receivers.clone();
  39. tokio::spawn(async move {
  40. while let Some(message) = server_ws_receiver.recv().await {
  41. match receivers.get(&message.channel) {
  42. None => tracing::error!("Can't find any handler for message: {:?}", message),
  43. Some(receiver) => receiver.receive_message(message.clone()),
  44. }
  45. }
  46. });
  47. FutureResult::new(async { Ok(()) })
  48. }
  49. fn start_connect(&self, _addr: String, user_id: String) -> FutureResult<(), FlowyError> {
  50. *self.user_id.write() = Some(user_id);
  51. FutureResult::new(async { Ok(()) })
  52. }
  53. fn stop_connect(&self) -> FutureResult<(), FlowyError> {
  54. FutureResult::new(async { Ok(()) })
  55. }
  56. fn subscribe_connect_state(&self) -> BoxFuture<Receiver<WSConnectState>> {
  57. let subscribe = self.state_sender.subscribe();
  58. Box::pin(async move { subscribe })
  59. }
  60. fn reconnect(&self, _count: usize) -> FutureResult<(), FlowyError> {
  61. FutureResult::new(async { Ok(()) })
  62. }
  63. fn add_msg_receiver(&self, receiver: Arc<dyn WSMessageReceiver>) -> Result<(), FlowyError> {
  64. tracing::trace!("Local web socket add ws receiver: {:?}", receiver.source());
  65. self.receivers.insert(receiver.source(), receiver);
  66. Ok(())
  67. }
  68. fn ws_msg_sender(&self) -> FutureResult<Option<Arc<dyn FlowyWebSocket>>, FlowyError> {
  69. let ws: Arc<dyn FlowyWebSocket> = Arc::new(LocalWebSocketAdaptor(self.server_ws_sender.clone()));
  70. FutureResult::new(async move { Ok(Some(ws)) })
  71. }
  72. }
  73. #[derive(Clone)]
  74. struct LocalWebSocketAdaptor(broadcast::Sender<WebSocketRawMessage>);
  75. impl FlowyWebSocket for LocalWebSocketAdaptor {
  76. fn send(&self, msg: WebSocketRawMessage) -> Result<(), FlowyError> {
  77. let _ = self.0.send(msg);
  78. Ok(())
  79. }
  80. }