connect.rs 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. #![allow(clippy::all)]
  2. use crate::{
  3. errors::{internal_error, WSError},
  4. MsgReceiver, MsgSender,
  5. };
  6. use futures_core::{future::BoxFuture, ready};
  7. use futures_util::{FutureExt, StreamExt};
  8. use pin_project::pin_project;
  9. use std::{
  10. fmt,
  11. future::Future,
  12. pin::Pin,
  13. task::{Context, Poll},
  14. };
  15. use tokio::net::TcpStream;
  16. use tokio_tungstenite::{
  17. connect_async,
  18. tungstenite::{handshake::client::Response, Error, Message},
  19. MaybeTlsStream, WebSocketStream,
  20. };
  /// Output of `connect_async`: the connected stream plus the server's handshake `Response`,
  /// or the tungstenite `Error` that aborted the connection attempt.
  type WsConnectResult = Result<(WebSocketStream<MaybeTlsStream<TcpStream>>, Response), Error>;
  /// Future that resolves to a [`WSStream`] once the WebSocket handshake completes.
  ///
  /// Both channel ends are stored as `Option`s so that `poll` can move them out exactly once,
  /// when the connection succeeds.
  #[pin_project]
  pub struct WSConnectionFuture {
      // Sender that incoming socket frames are pushed onto; handed to the `WSStream` on success.
      msg_tx: Option<MsgSender>,
      // Receiver of outgoing messages to write to the socket; handed to the `WSStream` on success.
      ws_rx: Option<MsgReceiver>,
      // The in-flight `connect_async` call. It is created once in `new` and re-polled until done,
      // so the handshake is never restarted between polls.
      #[pin]
      fut: Pin<Box<dyn Future<Output = WsConnectResult> + Send + Sync>>,
  }
  29. impl WSConnectionFuture {
  30. pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, addr: String) -> Self {
  31. WSConnectionFuture {
  32. msg_tx: Some(msg_tx),
  33. ws_rx: Some(ws_rx),
  34. fut: Box::pin(async move { connect_async(&addr).await }),
  35. }
  36. }
  37. }
  38. impl Future for WSConnectionFuture {
  39. type Output = Result<WSStream, WSError>;
  40. fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  41. // [[pin]]
  42. // poll async function. The following methods not work.
  43. // 1.
  44. // let f = connect_async("");
  45. // pin_mut!(f);
  46. // ready!(Pin::new(&mut a).poll(cx))
  47. //
  48. // 2.ready!(Pin::new(&mut Box::pin(connect_async(""))).poll(cx))
  49. //
  50. // An async method calls poll multiple times and might return to the executor. A
  51. // single poll call can only return to the executor once and will get
  52. // resumed through another poll invocation. the connect_async call multiple time
  53. // from the beginning. So I use fut to hold the future and continue to
  54. // poll it. (Fix me if i was wrong)
  55. loop {
  56. return match ready!(self.as_mut().project().fut.poll(cx)) {
  57. Ok((stream, _)) => {
  58. tracing::debug!("[WebSocket]: connect success");
  59. let (msg_tx, ws_rx) = (
  60. self
  61. .msg_tx
  62. .take()
  63. .expect("[WebSocket]: WSConnection should be call once "),
  64. self
  65. .ws_rx
  66. .take()
  67. .expect("[WebSocket]: WSConnection should be call once "),
  68. );
  69. Poll::Ready(Ok(WSStream::new(msg_tx, ws_rx, stream)))
  70. },
  71. Err(error) => {
  72. tracing::debug!("[WebSocket]: ❌ connect failed: {:?}", error);
  73. Poll::Ready(Err(error.into()))
  74. },
  75. };
  76. }
  77. }
  78. }
  /// Boxed `Send + 'static` future used for each direction of a [`WSStream`].
  type Fut = BoxFuture<'static, Result<(), WSError>>;
  /// Drives an open WebSocket connection using two futures: one reads frames from the socket,
  /// the other writes outgoing messages to it.
  #[pin_project]
  pub struct WSStream {
      // Clone of the incoming-frame sender. Not read anywhere in this file (hence `dead_code`).
      // NOTE(review): presumably kept so the channel stays open while the stream lives — confirm.
      #[allow(dead_code)]
      msg_tx: MsgSender,
      // `(reader, writer)` futures built in `WSStream::new`; `None` once either side has finished.
      #[pin]
      inner: Option<(Fut, Fut)>,
  }
  87. impl WSStream {
  88. pub fn new(
  89. msg_tx: MsgSender,
  90. ws_rx: MsgReceiver,
  91. stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
  92. ) -> Self {
  93. let (ws_write, ws_read) = stream.split();
  94. Self {
  95. msg_tx: msg_tx.clone(),
  96. inner: Some((
  97. Box::pin(async move {
  98. let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
  99. let read = async {
  100. ws_read
  101. .for_each(|message| async {
  102. match tx.send(send_message(msg_tx.clone(), message)) {
  103. Ok(_) => {},
  104. Err(e) => log::error!("[WebSocket]: WSStream sender closed unexpectedly: {} ", e),
  105. }
  106. })
  107. .await;
  108. Ok(())
  109. };
  110. let read_ret = async {
  111. loop {
  112. match rx.recv().await {
  113. None => {
  114. return Err(
  115. WSError::internal()
  116. .context("[WebSocket]: WSStream receiver closed unexpectedly"),
  117. );
  118. },
  119. Some(result) => {
  120. if result.is_err() {
  121. return result;
  122. }
  123. },
  124. }
  125. }
  126. };
  127. futures::pin_mut!(read);
  128. futures::pin_mut!(read_ret);
  129. return tokio::select! {
  130. result = read => result,
  131. result = read_ret => result,
  132. };
  133. }),
  134. Box::pin(async move {
  135. let result = ws_rx
  136. .map(Ok)
  137. .forward(ws_write)
  138. .await
  139. .map_err(internal_error);
  140. result
  141. }),
  142. )),
  143. }
  144. }
  145. }
  146. impl fmt::Debug for WSStream {
  147. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  148. f.debug_struct("WSStream").finish()
  149. }
  150. }
  151. impl Future for WSStream {
  152. type Output = Result<(), WSError>;
  153. fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
  154. let (mut ws_read, mut ws_write) = self.inner.take().unwrap();
  155. match ws_read.poll_unpin(cx) {
  156. Poll::Ready(l) => Poll::Ready(l),
  157. Poll::Pending => {
  158. //
  159. match ws_write.poll_unpin(cx) {
  160. Poll::Ready(r) => Poll::Ready(r),
  161. Poll::Pending => {
  162. self.inner = Some((ws_read, ws_write));
  163. Poll::Pending
  164. },
  165. }
  166. },
  167. }
  168. }
  169. }
  170. fn send_message(msg_tx: MsgSender, message: Result<Message, Error>) -> Result<(), WSError> {
  171. match message {
  172. Ok(Message::Binary(bytes)) => msg_tx
  173. .unbounded_send(Message::Binary(bytes))
  174. .map_err(internal_error),
  175. Ok(_) => Ok(()),
  176. Err(e) => Err(WSError::internal().context(e)),
  177. }
  178. }
  /// Future wrapper that runs the callback `f` against `addr` when polled.
  ///
  /// NOTE(review): `retry_time` is initialised to 3 but never read. `poll` invokes `f` exactly
  /// once and resolves immediately, and `f` returns `()`, so there is no failure signal a
  /// retry loop could act on yet.
  #[allow(dead_code)] // no callers in this file
  pub struct Retry<F> {
      f: F,
      #[allow(dead_code)] // never read; see the type-level note
      retry_time: usize,
      addr: String,
  }

  impl<F> Retry<F>
  where
      F: Fn(&str),
  {
      /// Builds a `Retry` for `addr` (copied into an owned `String`) with `retry_time` set to 3.
      #[allow(dead_code)] // no callers in this file
      pub fn new(addr: &str, f: F) -> Self {
          let addr = String::from(addr);
          Self {
              addr,
              retry_time: 3,
              f,
          }
      }
  }

  impl<F> Future for Retry<F>
  where
      F: Fn(&str),
  {
      type Output = ();

      /// Calls `f(&addr)` once and resolves immediately; the waker is never used.
      fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
          let Self { f, addr, .. } = &*self;
          f(addr.as_str());
          Poll::Ready(())
      }
  }
  208. }