Skip to content

Commit

Permalink
fix(#360): data not flushed immediatly on reverse tunnel
Browse files Browse the repository at this point in the history
  • Loading branch information
erebe committed Sep 27, 2024
1 parent b1e0982 commit 937ce59
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 38 deletions.
5 changes: 3 additions & 2 deletions src/tunnel/client/cnx_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::tunnel::client::l4_transport_stream::TransportStream;
use crate::tunnel::client::WsClientConfig;
use async_trait::async_trait;
use bb8::ManageConnection;
use bytes::Bytes;
use std::ops::Deref;
use std::sync::Arc;
use tracing::instrument;
Expand Down Expand Up @@ -58,9 +59,9 @@ impl ManageConnection for WsConnection {

if self.remote_addr.tls().is_some() {
let tls_stream = tls::connect(self, tcp_stream).await?;
Ok(Some(TransportStream::Tls(tls_stream)))
Ok(Some(TransportStream::Tls(tls_stream, Bytes::default())))
} else {
Ok(Some(TransportStream::Plain(tcp_stream)))
Ok(Some(TransportStream::Plain(tcp_stream, Bytes::default())))
}
}

Expand Down
60 changes: 44 additions & 16 deletions src/tunnel/client/l4_transport_stream.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,69 @@
use bytes::{Buf, Bytes};
use std::cmp;
use std::io::{Error, IoSlice};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;

pub enum TransportStream {
Plain(TcpStream),
Tls(TlsStream<TcpStream>),
Plain(TcpStream, Bytes),
Tls(tokio_rustls::client::TlsStream<TcpStream>, Bytes),
TlsSrv(tokio_rustls::server::TlsStream<TcpStream>, Bytes),
}

impl TransportStream {
pub fn read_buf_mut(&mut self) -> &mut Bytes {
match self {
Self::Plain(_, buf) => buf,
Self::Tls(_, buf) => buf,
Self::TlsSrv(_, buf) => buf,
}
}
}

impl AsyncRead for TransportStream {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
match self.get_mut() {
Self::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf),
Self::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf),
let this = self.get_mut();

let read_buf = this.read_buf_mut();
if !read_buf.is_empty() {
let copy_len = cmp::min(read_buf.len(), buf.remaining());
buf.put_slice(&read_buf[..copy_len]);
read_buf.advance(copy_len);
return Poll::Ready(Ok(()));
}

match this {
Self::Plain(cnx, _) => Pin::new(cnx).poll_read(cx, buf),
Self::Tls(cnx, _) => Pin::new(cnx).poll_read(cx, buf),
Self::TlsSrv(cnx, _) => Pin::new(cnx).poll_read(cx, buf),
}
}
}

impl AsyncWrite for TransportStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
match self.get_mut() {
Self::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf),
Self::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf),
Self::Plain(cnx, _) => Pin::new(cnx).poll_write(cx, buf),
Self::Tls(cnx, _) => Pin::new(cnx).poll_write(cx, buf),
Self::TlsSrv(cnx, _) => Pin::new(cnx).poll_write(cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self.get_mut() {
Self::Plain(cnx) => Pin::new(cnx).poll_flush(cx),
Self::Tls(cnx) => Pin::new(cnx).poll_flush(cx),
Self::Plain(cnx, _) => Pin::new(cnx).poll_flush(cx),
Self::Tls(cnx, _) => Pin::new(cnx).poll_flush(cx),
Self::TlsSrv(cnx, _) => Pin::new(cnx).poll_flush(cx),
}
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self.get_mut() {
Self::Plain(cnx) => Pin::new(cnx).poll_shutdown(cx),
Self::Tls(cnx) => Pin::new(cnx).poll_shutdown(cx),
Self::Plain(cnx, _) => Pin::new(cnx).poll_shutdown(cx),
Self::Tls(cnx, _) => Pin::new(cnx).poll_shutdown(cx),
Self::TlsSrv(cnx, _) => Pin::new(cnx).poll_shutdown(cx),
}
}

Expand All @@ -47,15 +73,17 @@ impl AsyncWrite for TransportStream {
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, Error>> {
match self.get_mut() {
Self::Plain(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
Self::Tls(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
Self::Plain(cnx, _) => Pin::new(cnx).poll_write_vectored(cx, bufs),
Self::Tls(cnx, _) => Pin::new(cnx).poll_write_vectored(cx, bufs),
Self::TlsSrv(cnx, _) => Pin::new(cnx).poll_write_vectored(cx, bufs),
}
}

fn is_write_vectored(&self) -> bool {
match &self {
Self::Plain(cnx) => cnx.is_write_vectored(),
Self::Tls(cnx) => cnx.is_write_vectored(),
Self::Plain(cnx, _) => cnx.is_write_vectored(),
Self::Tls(cnx, _) => cnx.is_write_vectored(),
Self::TlsSrv(cnx, _) => cnx.is_write_vectored(),
}
}
}
16 changes: 6 additions & 10 deletions src/tunnel/server/handler_websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use crate::restrictions::types::RestrictionsRules;
use crate::tunnel::server::utils::{bad_request, inject_cookie};
use crate::tunnel::server::WsServer;
use crate::tunnel::transport;
use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite};
use crate::tunnel::transport::websocket::mk_websocket_tunnel;
use bytes::Bytes;
use fastwebsockets::Role;
use http_body_util::combinators::BoxBody;
use http_body_util::Either;
use hyper::body::Incoming;
Expand Down Expand Up @@ -46,31 +47,26 @@ pub(super) async fn ws_server_upgrade(
tokio::spawn(
async move {
let (ws_rx, ws_tx) = match fut.await {
Ok(mut ws) => {
ws.set_auto_pong(false);
ws.set_auto_close(false);
ws.set_auto_apply_mask(mask_frame);
ws.split(tokio::io::split)
}
Ok(ws) => mk_websocket_tunnel(ws, Role::Server, mask_frame)?,
Err(err) => {
error!("Error during http upgrade request: {:?}", err);
return;
return Err(anyhow::Error::from(err));
}
};
let (close_tx, close_rx) = oneshot::channel::<()>();

let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx);
tokio::task::spawn(
transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).instrument(Span::current()),
);

let _ = transport::io::propagate_local_to_remote(
local_rx,
WebsocketTunnelWrite::new(ws_tx, pending_ops),
ws_tx,
close_tx,
server.config.websocket_ping_frequency,
)
.await;
Ok(())
}
.instrument(Span::current()),
);
Expand Down
60 changes: 50 additions & 10 deletions src/tunnel/transport/websocket.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use super::io::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
use crate::tunnel::client::l4_transport_stream::TransportStream;
use crate::tunnel::client::WsClient;
use crate::tunnel::transport::headers_from_file;
use crate::tunnel::transport::jwt::{tunnel_to_jwt_token, JWT_HEADER_PREFIX};
use crate::tunnel::RemoteAddr;
use anyhow::{anyhow, Context};
use bytes::{Bytes, BytesMut};
use fastwebsockets::{CloseCode, Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
use fastwebsockets::{CloseCode, Frame, OpCode, Payload, Role, WebSocket, WebSocketRead, WebSocketWrite};
use http_body_util::Empty;
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE};
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
Expand All @@ -22,13 +23,15 @@ use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::Relaxed;
use std::sync::Arc;
use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::Notify;
use tokio_rustls::server::TlsStream;
use tracing::trace;
use uuid::Uuid;

pub struct WebsocketTunnelWrite {
inner: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
inner: WebSocketWrite<WriteHalf<TransportStream>>,
buf: BytesMut,
pending_operations: Receiver<Frame<'static>>,
pending_ops_notify: Arc<Notify>,
Expand All @@ -37,7 +40,7 @@ pub struct WebsocketTunnelWrite {

impl WebsocketTunnelWrite {
pub fn new(
ws: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
ws: WebSocketWrite<WriteHalf<TransportStream>>,
(pending_operations, notify): (Receiver<Frame<'static>>, Arc<Notify>),
) -> Self {
Self {
Expand Down Expand Up @@ -146,13 +149,13 @@ impl TunnelWrite for WebsocketTunnelWrite {
}

pub struct WebsocketTunnelRead {
inner: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>,
inner: WebSocketRead<ReadHalf<TransportStream>>,
pending_operations: Sender<Frame<'static>>,
notify_pending_ops: Arc<Notify>,
}

impl WebsocketTunnelRead {
pub fn new(ws: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>) -> (Self, (Receiver<Frame<'static>>, Arc<Notify>)) {
pub fn new(ws: WebSocketRead<ReadHalf<TransportStream>>) -> (Self, (Receiver<Frame<'static>>, Arc<Notify>)) {
let (tx, rx) = tokio::sync::mpsc::channel(10);
let notify = Arc::new(Notify::new());
(
Expand Down Expand Up @@ -278,16 +281,53 @@ pub async fn connect(
})?;
debug!("with HTTP upgrade request {:?}", req);
let transport = pooled_cnx.deref_mut().take().unwrap();
let (mut ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport)
let (ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport)
.await
.with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?;

ws.set_auto_apply_mask(client_cfg.websocket_mask_frame);
ws.set_auto_close(false);
ws.set_auto_pong(false);
let (ws_rx, ws_tx) = mk_websocket_tunnel(ws, Role::Client, client_cfg.websocket_mask_frame)?;
Ok((ws_rx, ws_tx, response.into_parts().0))
}

pub fn mk_websocket_tunnel(
ws: WebSocket<TokioIo<Upgraded>>,
role: Role,
mask_frame: bool,
) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite)> {
let mut ws = match role {
Role::Client => {
let stream = ws
.into_inner()
.into_inner()
.downcast::<TokioIo<TransportStream>>()
.map_err(|_| anyhow!("cannot downcast websocket client stream"))?;
let mut transport = stream.io.into_inner();
*transport.read_buf_mut() = stream.read_buf;
WebSocket::after_handshake(transport, role)
}
Role::Server => {
let upgraded = ws.into_inner().into_inner();
match upgraded.downcast::<TokioIo<TlsStream<TcpStream>>>() {
Ok(stream) => {
let transport = TransportStream::TlsSrv(stream.io.into_inner(), stream.read_buf);
WebSocket::after_handshake(transport, role)
}
Err(upgraded) => {
let stream = upgraded
.downcast::<TokioIo<TcpStream>>()
.map_err(|_| anyhow!("cannot downcast websocket server stream"))?;
let transport = TransportStream::Plain(stream.io.into_inner(), stream.read_buf);
WebSocket::after_handshake(transport, role)
}
}
}
};

ws.set_auto_pong(false);
ws.set_auto_close(false);
ws.set_auto_apply_mask(mask_frame);
let (ws_rx, ws_tx) = ws.split(tokio::io::split);

let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx);
Ok((ws_rx, WebsocketTunnelWrite::new(ws_tx, pending_ops), response.into_parts().0))
Ok((ws_rx, WebsocketTunnelWrite::new(ws_tx, pending_ops)))
}

0 comments on commit 937ce59

Please sign in to comment.