ref_map.rs 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. use async_trait::async_trait;
  2. use std::collections::HashMap;
  3. use std::sync::Arc;
  4. #[async_trait]
  5. pub trait RefCountValue {
  6. async fn did_remove(&self) {}
  7. }
  8. struct RefCountHandler<T> {
  9. ref_count: usize,
  10. inner: T,
  11. }
  12. impl<T> RefCountHandler<T> {
  13. pub fn new(inner: T) -> Self {
  14. Self { ref_count: 1, inner }
  15. }
  16. pub fn increase_ref_count(&mut self) {
  17. self.ref_count += 1;
  18. }
  19. }
  20. pub struct RefCountHashMap<T>(HashMap<String, RefCountHandler<T>>);
  21. impl<T> std::default::Default for RefCountHashMap<T> {
  22. fn default() -> Self {
  23. Self(HashMap::new())
  24. }
  25. }
  26. impl<T> RefCountHashMap<T>
  27. where
  28. T: Clone + Send + Sync + RefCountValue + 'static,
  29. {
  30. pub fn new() -> Self {
  31. Self::default()
  32. }
  33. pub fn get(&self, key: &str) -> Option<T> {
  34. self.0.get(key).map(|handler| handler.inner.clone())
  35. }
  36. pub fn values(&self) -> Vec<T> {
  37. self.0.values().map(|value| value.inner.clone()).collect::<Vec<T>>()
  38. }
  39. pub fn insert(&mut self, key: String, value: T) {
  40. if let Some(handler) = self.0.get_mut(&key) {
  41. handler.increase_ref_count();
  42. } else {
  43. let handler = RefCountHandler::new(value);
  44. self.0.insert(key, handler);
  45. }
  46. }
  47. pub async fn remove(&mut self, key: &str) {
  48. let mut should_remove = false;
  49. if let Some(value) = self.0.get_mut(key) {
  50. if value.ref_count > 0 {
  51. value.ref_count -= 1;
  52. }
  53. should_remove = value.ref_count == 0;
  54. }
  55. if should_remove {
  56. if let Some(handler) = self.0.remove(key) {
  57. tokio::spawn(async move {
  58. handler.inner.did_remove().await;
  59. });
  60. }
  61. }
  62. }
  63. }
  64. #[async_trait]
  65. impl<T> RefCountValue for Arc<T>
  66. where
  67. T: RefCountValue + Sync + Send,
  68. {
  69. async fn did_remove(&self) {
  70. (**self).did_remove().await
  71. }
  72. }