ws_manager.rs 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. use crate::{entities::ws::WsDocumentData, errors::DocError};
  2. use bytes::Bytes;
  3. use dashmap::DashMap;
  4. use flowy_ws::WsState;
  5. use std::{collections::HashMap, convert::TryInto, sync::Arc};
  6. use tokio::sync::broadcast::error::RecvError;
  7. pub(crate) trait WsDocumentHandler: Send + Sync {
  8. fn receive(&self, data: WsDocumentData);
  9. fn state_changed(&self, state: &WsState);
  10. }
  11. pub type WsStateReceiver = tokio::sync::broadcast::Receiver<WsState>;
  12. pub trait DocumentWebSocket: Send + Sync {
  13. fn send(&self, data: WsDocumentData) -> Result<(), DocError>;
  14. fn state_notify(&self) -> WsStateReceiver;
  15. }
  16. pub struct WsDocumentManager {
  17. ws: Arc<dyn DocumentWebSocket>,
  18. // key: the document id
  19. handlers: Arc<DashMap<String, Arc<dyn WsDocumentHandler>>>,
  20. }
  21. impl WsDocumentManager {
  22. pub fn new(ws: Arc<dyn DocumentWebSocket>) -> Self {
  23. let handlers: Arc<DashMap<String, Arc<dyn WsDocumentHandler>>> = Arc::new(DashMap::new());
  24. listen_ws_state_changed(ws.clone(), handlers.clone());
  25. Self { ws, handlers }
  26. }
  27. pub(crate) fn register_handler(&self, id: &str, handler: Arc<dyn WsDocumentHandler>) {
  28. if self.handlers.contains_key(id) {
  29. log::error!("Duplicate handler registered for {:?}", id);
  30. }
  31. self.handlers.insert(id.to_string(), handler);
  32. }
  33. pub(crate) fn remove_handler(&self, id: &str) { self.handlers.remove(id); }
  34. pub fn handle_ws_data(&self, data: Bytes) {
  35. let data: WsDocumentData = data.try_into().unwrap();
  36. match self.handlers.get(&data.doc_id) {
  37. None => {
  38. log::error!("Can't find any source handler for {:?}", data.doc_id);
  39. },
  40. Some(handler) => {
  41. handler.receive(data);
  42. },
  43. }
  44. }
  45. pub fn ws(&self) -> Arc<dyn DocumentWebSocket> { self.ws.clone() }
  46. }
  47. #[tracing::instrument(level = "debug", skip(ws, handlers))]
  48. fn listen_ws_state_changed(ws: Arc<dyn DocumentWebSocket>, handlers: Arc<DashMap<String, Arc<dyn WsDocumentHandler>>>) {
  49. let mut notify = ws.state_notify();
  50. tokio::spawn(async move {
  51. loop {
  52. match notify.recv().await {
  53. Ok(state) => {
  54. handlers.iter().for_each(|handle| {
  55. handle.value().state_changed(&state);
  56. });
  57. },
  58. Err(e) => {
  59. log::error!("Websocket state notify error: {:?}", e);
  60. break;
  61. },
  62. }
  63. }
  64. });
  65. }