From 937ce59896168c02c06b51335af4dbd5fe863edb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Fri, 27 Sep 2024 14:06:20 +0200 Subject: [PATCH] fix(#360): data not flushed immediatly on reverse tunnel --- src/tunnel/client/cnx_pool.rs | 5 +- src/tunnel/client/l4_transport_stream.rs | 60 +++++++++++++++++------- src/tunnel/server/handler_websocket.rs | 16 +++---- src/tunnel/transport/websocket.rs | 60 ++++++++++++++++++++---- 4 files changed, 103 insertions(+), 38 deletions(-) diff --git a/src/tunnel/client/cnx_pool.rs b/src/tunnel/client/cnx_pool.rs index 0cef719..95e498c 100644 --- a/src/tunnel/client/cnx_pool.rs +++ b/src/tunnel/client/cnx_pool.rs @@ -4,6 +4,7 @@ use crate::tunnel::client::l4_transport_stream::TransportStream; use crate::tunnel::client::WsClientConfig; use async_trait::async_trait; use bb8::ManageConnection; +use bytes::Bytes; use std::ops::Deref; use std::sync::Arc; use tracing::instrument; @@ -58,9 +59,9 @@ impl ManageConnection for WsConnection { if self.remote_addr.tls().is_some() { let tls_stream = tls::connect(self, tcp_stream).await?; - Ok(Some(TransportStream::Tls(tls_stream))) + Ok(Some(TransportStream::Tls(tls_stream, Bytes::default()))) } else { - Ok(Some(TransportStream::Plain(tcp_stream))) + Ok(Some(TransportStream::Plain(tcp_stream, Bytes::default()))) } } diff --git a/src/tunnel/client/l4_transport_stream.rs b/src/tunnel/client/l4_transport_stream.rs index bbf55e1..fb9e337 100644 --- a/src/tunnel/client/l4_transport_stream.rs +++ b/src/tunnel/client/l4_transport_stream.rs @@ -1,20 +1,43 @@ +use bytes::{Buf, Bytes}; +use std::cmp; use std::io::{Error, IoSlice}; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::TcpStream; -use tokio_rustls::client::TlsStream; pub enum TransportStream { - Plain(TcpStream), - Tls(TlsStream), + Plain(TcpStream, Bytes), + Tls(tokio_rustls::client::TlsStream, Bytes), + TlsSrv(tokio_rustls::server::TlsStream, Bytes), +} + +impl TransportStream { + pub fn read_buf_mut(&mut self) -> &mut Bytes { + match self { + Self::Plain(_, buf) => buf, + Self::Tls(_, buf) => buf, + Self::TlsSrv(_, buf) => buf, + } + } } impl AsyncRead for TransportStream { fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { - match self.get_mut() { - Self::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf), - Self::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf), + let this = self.get_mut(); + + let read_buf = this.read_buf_mut(); + if !read_buf.is_empty() { + let copy_len = cmp::min(read_buf.len(), buf.remaining()); + buf.put_slice(&read_buf[..copy_len]); + read_buf.advance(copy_len); + return Poll::Ready(Ok(())); + } + + match this { + Self::Plain(cnx, _) => Pin::new(cnx).poll_read(cx, buf), + Self::Tls(cnx, _) => Pin::new(cnx).poll_read(cx, buf), + Self::TlsSrv(cnx, _) => Pin::new(cnx).poll_read(cx, buf), } } } @@ -22,22 +45,25 @@ impl AsyncRead for TransportStream { impl AsyncWrite for TransportStream { fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { match self.get_mut() { - Self::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf), - Self::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf), + Self::Plain(cnx, _) => Pin::new(cnx).poll_write(cx, buf), + Self::Tls(cnx, _) => Pin::new(cnx).poll_write(cx, buf), + Self::TlsSrv(cnx, _) => Pin::new(cnx).poll_write(cx, buf), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - Self::Plain(cnx) => Pin::new(cnx).poll_flush(cx), - Self::Tls(cnx) => Pin::new(cnx).poll_flush(cx), + Self::Plain(cnx, _) => Pin::new(cnx).poll_flush(cx), + Self::Tls(cnx, _) => Pin::new(cnx).poll_flush(cx), + Self::TlsSrv(cnx, _) => Pin::new(cnx).poll_flush(cx), } } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { - Self::Plain(cnx) => Pin::new(cnx).poll_shutdown(cx), - Self::Tls(cnx) => Pin::new(cnx).poll_shutdown(cx), + Self::Plain(cnx, _) => Pin::new(cnx).poll_shutdown(cx), + Self::Tls(cnx, _) => Pin::new(cnx).poll_shutdown(cx), + Self::TlsSrv(cnx, _) => Pin::new(cnx).poll_shutdown(cx), } } @@ -47,15 +73,17 @@ impl AsyncWrite for TransportStream { bufs: &[IoSlice<'_>], ) -> Poll> { match self.get_mut() { - Self::Plain(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs), - Self::Tls(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs), + Self::Plain(cnx, _) => Pin::new(cnx).poll_write_vectored(cx, bufs), + Self::Tls(cnx, _) => Pin::new(cnx).poll_write_vectored(cx, bufs), + Self::TlsSrv(cnx, _) => Pin::new(cnx).poll_write_vectored(cx, bufs), } } fn is_write_vectored(&self) -> bool { match &self { - Self::Plain(cnx) => cnx.is_write_vectored(), - Self::Tls(cnx) => cnx.is_write_vectored(), + Self::Plain(cnx, _) => cnx.is_write_vectored(), + Self::Tls(cnx, _) => cnx.is_write_vectored(), + Self::TlsSrv(cnx, _) => cnx.is_write_vectored(), } } } diff --git a/src/tunnel/server/handler_websocket.rs b/src/tunnel/server/handler_websocket.rs index 1c749ae..5599c9b 100644 --- a/src/tunnel/server/handler_websocket.rs +++ b/src/tunnel/server/handler_websocket.rs @@ -2,8 +2,9 @@ use crate::restrictions::types::RestrictionsRules; use crate::tunnel::server::utils::{bad_request, inject_cookie}; use crate::tunnel::server::WsServer; use crate::tunnel::transport; -use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite}; +use crate::tunnel::transport::websocket::mk_websocket_tunnel; use bytes::Bytes; +use fastwebsockets::Role; use http_body_util::combinators::BoxBody; use http_body_util::Either; use hyper::body::Incoming; @@ -46,31 +47,26 @@ pub(super) async fn ws_server_upgrade( tokio::spawn( async move { let (ws_rx, ws_tx) = match fut.await { - Ok(mut ws) => { - ws.set_auto_pong(false); - ws.set_auto_close(false); - ws.set_auto_apply_mask(mask_frame); - ws.split(tokio::io::split) - } + Ok(ws) => mk_websocket_tunnel(ws, Role::Server, mask_frame)?, Err(err) => { error!("Error during http upgrade request: {:?}", err); - return; + return Err(anyhow::Error::from(err)); } }; let (close_tx, close_rx) = oneshot::channel::<()>(); - let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx); tokio::task::spawn( transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).instrument(Span::current()), ); let _ = transport::io::propagate_local_to_remote( local_rx, - WebsocketTunnelWrite::new(ws_tx, pending_ops), + ws_tx, close_tx, server.config.websocket_ping_frequency, ) .await; + Ok(()) } .instrument(Span::current()), ); diff --git a/src/tunnel/transport/websocket.rs b/src/tunnel/transport/websocket.rs index a6d29c5..afbea74 100644 --- a/src/tunnel/transport/websocket.rs +++ b/src/tunnel/transport/websocket.rs @@ -1,11 +1,12 @@ use super::io::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH}; +use crate::tunnel::client::l4_transport_stream::TransportStream; use crate::tunnel::client::WsClient; use crate::tunnel::transport::headers_from_file; use crate::tunnel::transport::jwt::{tunnel_to_jwt_token, JWT_HEADER_PREFIX}; use crate::tunnel::RemoteAddr; use anyhow::{anyhow, Context}; use bytes::{Bytes, BytesMut}; -use fastwebsockets::{CloseCode, Frame, OpCode, Payload, WebSocketRead, WebSocketWrite}; +use fastwebsockets::{CloseCode, Frame, OpCode, Payload, Role, WebSocket, WebSocketRead, WebSocketWrite}; use http_body_util::Empty; use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE}; use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY}; @@ -22,13 +23,15 @@ use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; +use tokio::net::TcpStream; use tokio::sync::mpsc::{Receiver, Sender}; use tokio::sync::Notify; +use tokio_rustls::server::TlsStream; use tracing::trace; use uuid::Uuid; pub struct WebsocketTunnelWrite { - inner: WebSocketWrite>>, + inner: WebSocketWrite>, buf: BytesMut, pending_operations: Receiver>, pending_ops_notify: Arc, @@ -37,7 +40,7 @@ pub struct WebsocketTunnelWrite { impl WebsocketTunnelWrite { pub fn new( - ws: WebSocketWrite>>, + ws: WebSocketWrite>, (pending_operations, notify): (Receiver>, Arc), ) -> Self { Self { @@ -146,13 +149,13 @@ impl TunnelWrite for WebsocketTunnelWrite { } pub struct WebsocketTunnelRead { - inner: WebSocketRead>>, + inner: WebSocketRead>, pending_operations: Sender>, notify_pending_ops: Arc, } impl WebsocketTunnelRead { - pub fn new(ws: WebSocketRead>>) -> (Self, (Receiver>, Arc)) { + pub fn new(ws: WebSocketRead>) -> (Self, (Receiver>, Arc)) { let (tx, rx) = tokio::sync::mpsc::channel(10); let notify = Arc::new(Notify::new()); ( @@ -278,16 +281,53 @@ pub async fn connect( })?; debug!("with HTTP upgrade request {:?}", req); let transport = pooled_cnx.deref_mut().take().unwrap(); - let (mut ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport) + let (ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport) .await .with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?; - ws.set_auto_apply_mask(client_cfg.websocket_mask_frame); - ws.set_auto_close(false); - ws.set_auto_pong(false); + let (ws_rx, ws_tx) = mk_websocket_tunnel(ws, Role::Client, client_cfg.websocket_mask_frame)?; + Ok((ws_rx, ws_tx, response.into_parts().0)) +} +pub fn mk_websocket_tunnel( + ws: WebSocket>, + role: Role, + mask_frame: bool, +) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite)> { + let mut ws = match role { + Role::Client => { + let stream = ws + .into_inner() + .into_inner() + .downcast::>() + .map_err(|_| anyhow!("cannot downcast websocket client stream"))?; + let mut transport = stream.io.into_inner(); + *transport.read_buf_mut() = stream.read_buf; + WebSocket::after_handshake(transport, role) + } + Role::Server => { + let upgraded = ws.into_inner().into_inner(); + match upgraded.downcast::>>() { + Ok(stream) => { + let transport = TransportStream::TlsSrv(stream.io.into_inner(), stream.read_buf); + WebSocket::after_handshake(transport, role) + } + Err(upgraded) => { + let stream = upgraded + .downcast::>() + .map_err(|_| anyhow!("cannot downcast websocket server stream"))?; + let transport = TransportStream::Plain(stream.io.into_inner(), stream.read_buf); + WebSocket::after_handshake(transport, role) + } + } + } + }; + + ws.set_auto_pong(false); + ws.set_auto_close(false); + ws.set_auto_apply_mask(mask_frame); let (ws_rx, ws_tx) = ws.split(tokio::io::split); let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx); - Ok((ws_rx, WebsocketTunnelWrite::new(ws_tx, pending_ops), response.into_parts().0)) + Ok((ws_rx, WebsocketTunnelWrite::new(ws_tx, pending_ops))) }