added streamlocal-forward support (remote UNIX socket forwarding) (#1243)

This commit is contained in:
Eugene 2025-02-09 15:28:22 +01:00 committed by GitHub
parent 2cdf8babae
commit 55dcd11a17
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 250 additions and 54 deletions

4
Cargo.lock generated
View file

@ -3675,9 +3675,9 @@ dependencies = [
[[package]]
name = "russh"
version = "0.50.0"
version = "0.50.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "016fe9bba877904367d9f12c8599948ba2176d01783f78b20348fb138f2337a8"
checksum = "80312b5f2cd37e542093244b369c72b9dbf1ecf4880f8463ba0dbfc949f52fec"
dependencies = [
"aes",
"aes-gcm",

View file

@ -24,8 +24,7 @@ bytes = "1.4"
data-encoding = "2.3"
serde = "1.0"
serde_json = "1.0"
russh = { version = "0.50.0", features = ["des"] }
tracing = "0.1"
russh = { version = "0.50.1", features = ["des"] }
futures = "0.3"
tokio-stream = { version = "0.1.17", features = ["net"] }
tokio-rustls = "0.26"
@ -45,6 +44,7 @@ poem = { version = "3.1", features = [
] }
password-hash = { version = "0.4", features = ["std"] }
delegate = "0.13"
tracing = "0.1"
[profile.release]
lto = true

View file

@ -4,4 +4,5 @@ deny = [
"clippy::expect_used",
"clippy::panic",
"clippy::indexing_slicing",
"dbg_macro",
]

View file

@ -17,11 +17,16 @@ pub struct TrafficRecorder {
}
#[derive(Debug)]
pub struct TrafficConnectionParams {
pub src_addr: Ipv4Addr,
pub src_port: u16,
pub dst_addr: Ipv4Addr,
pub dst_port: u16,
pub enum TrafficConnectionParams {
Tcp {
src_addr: Ipv4Addr,
src_port: u16,
dst_addr: Ipv4Addr,
dst_port: u16,
},
Socket {
socket_path: String,
},
}
impl TrafficRecorder {
@ -136,30 +141,47 @@ impl ConnectionRecorder {
where
F: FnOnce(packet::ip::v4::Builder) -> Result<Bytes>,
{
f(packet::ip::v4::Builder::default()
.protocol(packet::ip::Protocol::Tcp)?
.source(self.params.src_addr)?
.destination(self.params.dst_addr)?)
match self.params {
TrafficConnectionParams::Socket { .. } => f(packet::ip::v4::Builder::default()
.protocol(packet::ip::Protocol::Tcp)?
.source(Ipv4Addr::UNSPECIFIED)?
.destination(Ipv4Addr::BROADCAST)?),
TrafficConnectionParams::Tcp {
src_addr, dst_addr, ..
} => f(packet::ip::v4::Builder::default()
.protocol(packet::ip::Protocol::Tcp)?
.source(src_addr)?
.destination(dst_addr)?),
}
}
fn ip_packet_rx<F>(&self, f: F) -> Result<Bytes>
where
F: FnOnce(packet::ip::v4::Builder) -> Result<Bytes>,
{
f(packet::ip::v4::Builder::default()
.protocol(packet::ip::Protocol::Tcp)?
.source(self.params.dst_addr)?
.destination(self.params.src_addr)?)
match self.params {
TrafficConnectionParams::Socket { .. } => f(packet::ip::v4::Builder::default()
.protocol(packet::ip::Protocol::Tcp)?
.source(Ipv4Addr::BROADCAST)?
.destination(Ipv4Addr::UNSPECIFIED)?),
TrafficConnectionParams::Tcp {
src_addr, dst_addr, ..
} => f(packet::ip::v4::Builder::default()
.protocol(packet::ip::Protocol::Tcp)?
.source(dst_addr)?
.destination(src_addr)?),
}
}
fn tcp_packet_tx<F>(&self, f: F) -> Result<Bytes>
where
F: FnOnce(packet::tcp::Builder) -> Result<Bytes>,
{
self.ip_packet_tx(|b| {
f(b.tcp()?
.source(self.params.src_port)?
.destination(self.params.dst_port)?)
self.ip_packet_tx(|b| match self.params {
TrafficConnectionParams::Socket { .. } => f(b.tcp()?.source(0)?.destination(0)?),
TrafficConnectionParams::Tcp {
src_port, dst_port, ..
} => f(b.tcp()?.source(src_port)?.destination(dst_port)?),
})
}
@ -167,10 +189,11 @@ impl ConnectionRecorder {
where
F: FnOnce(packet::tcp::Builder) -> Result<Bytes>,
{
self.ip_packet_rx(|b| {
f(b.tcp()?
.source(self.params.dst_port)?
.destination(self.params.src_port)?)
self.ip_packet_rx(|b| match self.params {
TrafficConnectionParams::Socket { .. } => f(b.tcp()?.source(0)?.destination(0)?),
TrafficConnectionParams::Tcp {
src_port, dst_port, ..
} => f(b.tcp()?.source(dst_port)?.destination(src_port)?),
})
}

View file

@ -8,13 +8,14 @@ use warpgate_common::{SessionId, TargetSSHOptions};
use warpgate_core::Services;
use crate::known_hosts::{KnownHostValidationResult, KnownHosts};
use crate::{ConnectionError, ForwardedTcpIpParams};
use crate::{ConnectionError, ForwardedStreamlocalParams, ForwardedTcpIpParams};
#[derive(Debug)]
pub enum ClientHandlerEvent {
HostKeyReceived(PublicKey),
HostKeyUnknown(PublicKey, oneshot::Sender<bool>),
ForwardedTcpIp(Channel<Msg>, ForwardedTcpIpParams),
ForwardedStreamlocal(Channel<Msg>, ForwardedStreamlocalParams),
X11(Channel<Msg>, String, u32),
Disconnect,
}
@ -146,6 +147,20 @@ impl russh::client::Handler for ClientHandler {
));
Ok(())
}
async fn server_channel_open_forwarded_streamlocal(
&mut self,
channel: Channel<Msg>,
socket_path: &str,
_session: &mut Session,
) -> Result<(), Self::Error> {
let socket_path = socket_path.to_string();
let _ = self.event_tx.send(ClientHandlerEvent::ForwardedStreamlocal(
channel,
ForwardedStreamlocalParams { socket_path },
));
Ok(())
}
}
impl Drop for ClientHandler {

View file

@ -29,7 +29,7 @@ use warpgate_core::Services;
use self::handler::ClientHandlerEvent;
use super::{ChannelOperation, DirectTCPIPParams};
use crate::client::handler::ClientHandlerError;
use crate::{load_all_usable_private_keys, ForwardedTcpIpParams};
use crate::{load_all_usable_private_keys, ForwardedStreamlocalParams, ForwardedTcpIpParams};
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
@ -91,6 +91,7 @@ pub enum RCEvent {
HostKeyReceived(PublicKey),
HostKeyUnknown(PublicKey, oneshot::Sender<bool>),
ForwardedTcpIp(Uuid, ForwardedTcpIpParams),
ForwardedStreamlocal(Uuid, ForwardedStreamlocalParams),
X11(Uuid, String, u32),
}
@ -102,6 +103,8 @@ pub enum RCCommand {
Channel(Uuid, ChannelOperation),
ForwardTCPIP(String, u32),
CancelTCPIPForward(String, u32),
StreamlocalForward(String),
CancelStreamlocalForward(String),
Disconnect,
}
@ -126,6 +129,7 @@ pub struct RemoteClient {
channel_pipes: Arc<Mutex<HashMap<Uuid, UnboundedSender<ChannelOperation>>>>,
pending_ops: Vec<(Uuid, ChannelOperation)>,
pending_forwards: Vec<(String, u32)>,
pending_streamlocal_forwards: Vec<String>,
state: RCState,
abort_rx: UnboundedReceiver<()>,
inner_event_rx: UnboundedReceiver<InnerEvent>,
@ -155,6 +159,7 @@ impl RemoteClient {
channel_pipes: Arc::new(Mutex::new(HashMap::new())),
pending_ops: vec![],
pending_forwards: vec![],
pending_streamlocal_forwards: vec![],
state: RCState::NotInitialized,
inner_event_rx,
inner_event_tx: inner_event_tx.clone(),
@ -309,6 +314,11 @@ impl RemoteClient {
let id = self.setup_server_initiated_channel(channel).await?;
let _ = self.tx.send(RCEvent::ForwardedTcpIp(id, params));
}
ClientHandlerEvent::ForwardedStreamlocal(channel, params) => {
info!("New forwarded socket connection: {params:?}");
let id = self.setup_server_initiated_channel(channel).await?;
let _ = self.tx.send(RCEvent::ForwardedStreamlocal(id, params));
}
ClientHandlerEvent::X11(channel, originator_address, originator_port) => {
info!("New X11 connection from {originator_address}:{originator_port:?}");
let id = self.setup_server_initiated_channel(channel).await?;
@ -355,10 +365,19 @@ impl RemoteClient {
for (id, op) in ops {
self.apply_channel_op(id, op).await?;
}
let forwards = self.pending_forwards.drain(..).collect::<Vec<_>>();
for (address, port) in forwards {
self.tcpip_forward(address, port).await?;
}
let forwards = self
.pending_streamlocal_forwards
.drain(..)
.collect::<Vec<_>>();
for socket_path in forwards {
self.streamlocal_forward(socket_path).await?;
}
}
Err(e) => {
debug!("Connect error: {}", e);
@ -376,6 +395,12 @@ impl RemoteClient {
RCCommand::CancelTCPIPForward(address, port) => {
self.cancel_tcpip_forward(address, port).await?;
}
RCCommand::StreamlocalForward(socket_path) => {
self.streamlocal_forward(socket_path).await?;
}
RCCommand::CancelStreamlocalForward(socket_path) => {
self.cancel_streamlocal_forward(socket_path).await?;
}
RCCommand::Disconnect => {
self.disconnect().await;
return Ok(true);
@ -635,6 +660,30 @@ impl RemoteClient {
Ok(())
}
async fn streamlocal_forward(&mut self, socket_path: String) -> Result<(), SshClientError> {
if let Some(session) = &self.session {
let mut session = session.lock().await;
session.streamlocal_forward(socket_path).await?;
} else {
self.pending_streamlocal_forwards.push(socket_path);
}
Ok(())
}
async fn cancel_streamlocal_forward(
&mut self,
socket_path: String,
) -> Result<(), SshClientError> {
if let Some(session) = &self.session {
let session = session.lock().await;
session.cancel_streamlocal_forward(socket_path).await?;
} else {
self.pending_streamlocal_forwards
.retain(|x| x != &socket_path);
}
Ok(())
}
async fn disconnect(&mut self) {
if let Some(session) = &mut self.session {
let _ = session

View file

@ -38,6 +38,11 @@ pub struct ForwardedTcpIpParams {
pub originator_port: u32,
}
#[derive(Clone, Debug)]
pub struct ForwardedStreamlocalParams {
pub socket_path: String,
}
#[derive(Clone, Debug)]
pub struct X11Request {
pub single_conection: bool,

View file

@ -47,6 +47,8 @@ pub enum ServerHandlerEvent {
X11Request(ServerChannelId, X11Request, oneshot::Sender<()>),
TcpIpForward(String, u32, oneshot::Sender<bool>),
CancelTcpIpForward(String, u32, oneshot::Sender<bool>),
StreamlocalForward(String, oneshot::Sender<bool>),
CancelStreamlocalForward(String, oneshot::Sender<bool>),
Disconnect,
}
@ -181,7 +183,6 @@ impl russh::server::Handler for ServerHandler {
user: &str,
key: &russh::keys::PublicKey,
) -> Result<Auth, Self::Error> {
dbg!(key);
let user = Secret::new(user.to_string());
let (tx, rx) = oneshot::channel();
@ -478,6 +479,43 @@ impl russh::server::Handler for ServerHandler {
}
Ok(allowed)
}
async fn streamlocal_forward(
&mut self,
socket_path: &str,
session: &mut Session,
) -> Result<bool, Self::Error> {
let socket_path = socket_path.to_string();
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::StreamlocalForward(socket_path, tx))?;
let allowed = rx.await.unwrap_or(false);
if allowed {
session.request_success()
} else {
session.request_failure()
}
Ok(allowed)
}
async fn cancel_streamlocal_forward(
&mut self,
socket_path: &str,
session: &mut Session,
) -> Result<bool, Self::Error> {
let socket_path = socket_path.to_string();
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::CancelStreamlocalForward(
socket_path,
tx,
))?;
let allowed = rx.await.unwrap_or(false);
if allowed {
session.request_success()
} else {
session.request_failure()
}
Ok(allowed)
}
}
impl Drop for ServerHandler {

View file

@ -71,6 +71,21 @@ struct CachedSuccessfulTicketAuth {
username: String,
}
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub enum TrafficRecorderKey {
Tcp(String, u32),
Socket(String),
}
impl TrafficRecorderKey {
pub fn to_name(&self) -> String {
match self {
TrafficRecorderKey::Tcp(addr, port) => format!("{addr}-{port}"),
TrafficRecorderKey::Socket(path) => path.clone().replace("/", "-"),
}
}
}
pub struct ServerSession {
pub id: SessionId,
username: Option<String>,
@ -87,7 +102,7 @@ pub struct ServerSession {
services: Services,
server_handle: Arc<Mutex<WarpgateServerHandle>>,
target: TargetSelection,
traffic_recorders: HashMap<(String, u32), TrafficRecorder>,
traffic_recorders: HashMap<TrafficRecorderKey, TrafficRecorder>,
traffic_connection_recorders: HashMap<Uuid, ConnectionRecorder>,
hub: EventHub<Event>,
event_sender: EventSender<Event>,
@ -556,6 +571,15 @@ impl ServerSession {
let _ = reply.send(true);
}
ServerHandlerEvent::StreamlocalForward(socket_path, reply) => {
self._streamlocal_forward(socket_path).await?;
let _ = reply.send(true);
}
ServerHandlerEvent::CancelStreamlocalForward(socket_path, reply) => {
self._cancel_streamlocal_forward(socket_path).await?;
let _ = reply.send(true);
}
ServerHandlerEvent::Disconnect => (),
}
@ -791,14 +815,16 @@ impl ServerSession {
let recorder = self
.traffic_recorder_for(
&params.originator_address,
params.originator_port,
TrafficRecorderKey::Tcp(
params.originator_address,
params.originator_port,
),
"forwarded-tcpip",
)
.await;
if let Some(recorder) = recorder {
#[allow(clippy::unwrap_used)]
let mut recorder = recorder.connection(TrafficConnectionParams {
let mut recorder = recorder.connection(TrafficConnectionParams::Tcp {
dst_addr: Ipv4Addr::from_str("2.2.2.2").unwrap(),
dst_port: params.connected_port as u16,
src_addr: Ipv4Addr::from_str("1.1.1.1").unwrap(),
@ -811,6 +837,34 @@ impl ServerSession {
}
}
}
RCEvent::ForwardedStreamlocal(id, params) => {
if let Some(session) = &mut self.session_handle {
let server_channel = session
.channel_open_forwarded_streamlocal(params.socket_path.clone())
.await?;
self.channel_map
.insert(ServerChannelId(server_channel.id()), id);
self.all_channels.push(id);
let recorder = self
.traffic_recorder_for(
TrafficRecorderKey::Socket(params.socket_path.clone()),
"forwarded-streamlocal",
)
.await;
if let Some(recorder) = recorder {
#[allow(clippy::unwrap_used)]
let mut recorder = recorder.connection(TrafficConnectionParams::Socket {
socket_path: params.socket_path,
});
if let Err(error) = recorder.write_connection_setup().await {
error!(channel=%id, ?error, "Failed to record connection setup");
}
self.traffic_connection_recorders.insert(id, recorder);
}
}
}
RCEvent::X11(id, originator_address, originator_port) => {
if let Some(session) = &mut self.session_handle {
let server_channel = session
@ -933,14 +987,13 @@ impl ServerSession {
let recorder = self
.traffic_recorder_for(
&params.host_to_connect,
params.port_to_connect,
TrafficRecorderKey::Tcp(params.host_to_connect, params.port_to_connect),
"direct-tcpip",
)
.await;
if let Some(recorder) = recorder {
#[allow(clippy::unwrap_used)]
let mut recorder = recorder.connection(TrafficConnectionParams {
let mut recorder = recorder.connection(TrafficConnectionParams::Tcp {
dst_addr: Ipv4Addr::from_str("2.2.2.2").unwrap(),
dst_port: params.port_to_connect as u16,
src_addr: Ipv4Addr::from_str("1.1.1.1").unwrap(),
@ -1072,29 +1125,27 @@ impl ServerSession {
async fn traffic_recorder_for(
&mut self,
host: &str,
port: u32,
key: TrafficRecorderKey,
tag: &str,
) -> Option<&mut TrafficRecorder> {
let host = host.to_owned();
if let Vacant(e) = self.traffic_recorders.entry((host.clone(), port)) {
if let Vacant(e) = self.traffic_recorders.entry(key.clone()) {
match self
.services
.recordings
.lock()
.await
.start(&self.id, format!("{tag}-{host}-{port}"))
.start(&self.id, format!("{tag}-{}", key.to_name()))
.await
{
Ok(recorder) => {
e.insert(recorder);
}
Err(error) => {
error!(%host, %port, ?error, "Failed to start recording");
error!(?key, ?error, "Failed to start recording");
}
}
}
self.traffic_recorders.get_mut(&(host, port))
self.traffic_recorders.get_mut(&key)
}
pub async fn _channel_subsystem_request(
@ -1180,6 +1231,21 @@ impl ServerSession {
.map_err(anyhow::Error::from)
}
async fn _streamlocal_forward(&mut self, socket_path: String) -> Result<()> {
info!(%socket_path, "Remote UNIX socket forwarding requested");
let _ = self.maybe_connect_remote().await;
self.send_command_and_wait(RCCommand::StreamlocalForward(socket_path))
.await
.map_err(anyhow::Error::from)
}
pub async fn _cancel_streamlocal_forward(&mut self, socket_path: String) -> Result<()> {
info!(%socket_path, "Remote UNIX socket forwarding cancelled");
self.send_command_and_wait(RCCommand::CancelStreamlocalForward(socket_path))
.await
.map_err(anyhow::Error::from)
}
async fn _auth_publickey_offer(
&mut self,
ssh_username: Secret<String>,

View file

@ -165,13 +165,10 @@ impl SsoClient {
})?;
let mut token_verifier = client.id_token_verifier();
dbg!(self.config.additional_trusted_audiences());
if let Some(trusted_audiences) = self.config.additional_trusted_audiences() {
token_verifier = token_verifier.set_other_audience_verifier_fn(|aud| {
dbg!(aud);
trusted_audiences.contains(aud.deref())
});
token_verifier = token_verifier
.set_other_audience_verifier_fn(|aud| trusted_audiences.contains(aud.deref()));
}
let id_token: &CoreIdToken = token_response.id_token().ok_or(SsoError::NotOidc)?;

View file

@ -82,7 +82,10 @@
</div>
<div class="narrow-page">
<Form>
<Form on:submit={e => {
create()
e.preventDefault()
}}>
<!-- svelte-ignore a11y_label_has_associated_control -->
<label class="mb-2">Type</label>
<ButtonGroup class="w-100 mb-3">
@ -108,10 +111,9 @@
<input class="form-control" required bind:value={name} />
</FormGroup>
<AsyncButton
color="primary"
click={create}
>Create target</AsyncButton>
<Button
color="primary"
type="submit"
>Create target</Button>
</Form>
</div>

View file

@ -19,7 +19,7 @@ async function load () {
}
function getTCPDumpURL () {
return `/@warpgate/api/recordings/${recording?.id}/tcpdump`
return `/@warpgate/admin/api/recordings/${recording?.id}/tcpdump`
}
load().catch(async e => {