ref_map.rs 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. use async_trait::async_trait;
  2. use std::collections::HashMap;
  3. use std::sync::Arc;
  4. #[async_trait]
  5. pub trait RefCountValue {
  6. async fn did_remove(&self) {}
  7. }
  8. struct RefCountHandler<T> {
  9. ref_count: usize,
  10. inner: T,
  11. }
  12. impl<T> RefCountHandler<T> {
  13. pub fn new(inner: T) -> Self {
  14. Self {
  15. ref_count: 1,
  16. inner,
  17. }
  18. }
  19. pub fn increase_ref_count(&mut self) {
  20. self.ref_count += 1;
  21. }
  22. }
  23. pub struct RefCountHashMap<T>(HashMap<String, RefCountHandler<T>>);
  24. impl<T> std::default::Default for RefCountHashMap<T> {
  25. fn default() -> Self {
  26. Self(HashMap::new())
  27. }
  28. }
  29. impl<T> RefCountHashMap<T>
  30. where
  31. T: Clone + Send + Sync + RefCountValue + 'static,
  32. {
  33. pub fn new() -> Self {
  34. Self::default()
  35. }
  36. pub fn get(&self, key: &str) -> Option<T> {
  37. self.0.get(key).map(|handler| handler.inner.clone())
  38. }
  39. pub fn values(&self) -> Vec<T> {
  40. self
  41. .0
  42. .values()
  43. .map(|value| value.inner.clone())
  44. .collect::<Vec<T>>()
  45. }
  46. pub fn insert(&mut self, key: String, value: T) {
  47. if let Some(handler) = self.0.get_mut(&key) {
  48. handler.increase_ref_count();
  49. } else {
  50. let handler = RefCountHandler::new(value);
  51. self.0.insert(key, handler);
  52. }
  53. }
  54. pub async fn remove(&mut self, key: &str) {
  55. let mut should_remove = false;
  56. if let Some(value) = self.0.get_mut(key) {
  57. if value.ref_count > 0 {
  58. value.ref_count -= 1;
  59. }
  60. should_remove = value.ref_count == 0;
  61. }
  62. if should_remove {
  63. if let Some(handler) = self.0.remove(key) {
  64. tokio::spawn(async move {
  65. handler.inner.did_remove().await;
  66. });
  67. }
  68. }
  69. }
  70. }
  71. #[async_trait]
  72. impl<T> RefCountValue for Arc<T>
  73. where
  74. T: RefCountValue + Sync + Send,
  75. {
  76. async fn did_remove(&self) {
  77. (**self).did_remove().await
  78. }
  79. }