Merge branch 'handler-channel' into db

This commit is contained in:
Eugene Pankov 2022-09-01 19:24:30 +02:00
commit 0808837f70
No known key found for this signature in database
GPG key ID: 5896FCBBDD1CF4F4
11 changed files with 540 additions and 426 deletions

View file

@ -88,7 +88,7 @@ class Test:
password='123',
)
output = ssh_client.communicate()[0]
output = ssh_client.communicate(timeout=10)[0]
assert b'Warpgate' in output
assert b'Selected target:' in output
assert b'hello\r\n' in output

View file

@ -43,6 +43,7 @@ class Test:
'IdentityFile=ssh-keys/id_ed25519',
'-o',
'PreferredAuthentications=publickey',
# 'sh', '-c', '"ls /bin/sh;sleep 1"',
'ls',
'/bin/sh',
)

View file

@ -1,4 +1,4 @@
use bytes::{Bytes, BytesMut};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use tokio::time::Instant;
use warpgate_db_entities::Recording::RecordingKind;
@ -93,7 +93,7 @@ impl TerminalRecorder {
self.write_item(&TerminalRecordingItem::Data {
time: self.get_time(),
stream,
data: BytesMut::from(data).freeze(),
data: Bytes::from(data.to_vec()),
})
.await
}

View file

@ -3,7 +3,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use bytes::{Bytes, BytesMut};
use bytes::Bytes;
use sea_orm::{ActiveModelTrait, DatabaseConnection, EntityTrait};
use tokio::fs::File;
use tokio::io::{AsyncWriteExt, BufWriter};
@ -102,7 +102,7 @@ impl RecordingWriter {
}
pub async fn write(&mut self, data: &[u8]) -> Result<()> {
let data = BytesMut::from(data).freeze();
let data = Bytes::from(data.to_vec());
self.sender
.send(data.clone())
.await

View file

@ -1,5 +1,5 @@
use anyhow::Result;
use bytes::{Bytes, BytesMut};
use bytes::Bytes;
use russh::client::Msg;
use russh::Channel;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
@ -59,7 +59,7 @@ impl DirectTCPIPChannel {
let bytes: &[u8] = &data;
self.events_tx.send(RCEvent::Output(
self.channel_id,
Bytes::from(BytesMut::from(bytes)),
Bytes::from(bytes.to_vec()),
)).map_err(|_| SshClientError::MpscError)?;
}
Some(russh::ChannelMsg::Close) => {

View file

@ -1,5 +1,5 @@
use anyhow::Result;
use bytes::{Bytes, BytesMut};
use bytes::Bytes;
use russh::client::Msg;
use russh::Channel;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
@ -106,7 +106,7 @@ impl SessionChannel {
debug!("channel data: {bytes:?}");
self.events_tx.send(RCEvent::Output(
self.channel_id,
Bytes::from(BytesMut::from(bytes)),
Bytes::from(bytes.to_vec()),
)).map_err(|_| SshClientError::MpscError)?;
}
Some(russh::ChannelMsg::Close) => {
@ -138,7 +138,7 @@ impl SessionChannel {
let data: &[u8] = &data;
self.events_tx.send(RCEvent::ExtendedData {
channel: self.channel_id,
data: Bytes::from(BytesMut::from(data)),
data: Bytes::from(data.to_vec()),
ext,
}).map_err(|_| SshClientError::MpscError)?;
}

View file

@ -0,0 +1,24 @@
use russh::server::Handle;
use russh::{ChannelId, CryptoVec};
use tokio::sync::mpsc;
/// Sequences data writes and runs them in background to avoid lockups
pub struct ChannelWriter {
tx: mpsc::UnboundedSender<(Handle, ChannelId, CryptoVec)>,
}
impl ChannelWriter {
pub fn new() -> Self {
let (tx, mut rx) = mpsc::unbounded_channel::<(Handle, ChannelId, CryptoVec)>();
tokio::spawn(async move {
while let Some((handle, channel, data)) = rx.recv().await {
let _ = handle.data(channel, data).await;
}
});
ChannelWriter { tx }
}
pub fn write(&self, handle: Handle, channel: ChannelId, data: CryptoVec) {
let _ = self.tx.send((handle, channel, data));
}
}

View file

@ -1,3 +1,4 @@
mod channel_writer;
mod russh_handler;
mod service_output;
mod session;
@ -12,6 +13,7 @@ pub use russh_handler::ServerHandler;
pub use session::ServerSession;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::sync::mpsc::unbounded_channel;
use tracing::*;
use warpgate_core::{Services, SessionStateInit};
@ -63,18 +65,29 @@ pub async fn run_server(services: Services, address: SocketAddr) -> Result<()> {
let id = server_handle.lock().await.id();
let session =
match ServerSession::new(remote_address, &services, server_handle, session_handle_rx)
.await
{
Ok(session) => session,
Err(error) => {
error!(%error, "Error setting up session");
continue;
}
};
let (event_tx, event_rx) = unbounded_channel();
let handler = ServerHandler { id, session };
let handler = ServerHandler { id, event_tx };
let session = match ServerSession::new(
remote_address,
&services,
server_handle,
session_handle_rx,
event_rx,
)
.await
{
Ok(session) => session,
Err(error) => {
error!(%error, "Error setting up session");
continue;
}
};
tokio::task::Builder::new()
.name(&format!("SSH {id} session"))
.spawn(session);
tokio::task::Builder::new()
.name(&format!("SSH {id} protocol"))

View file

@ -1,22 +1,71 @@
use std::fmt::Debug;
use std::pin::Pin;
use std::sync::Arc;
use bytes::BytesMut;
use bytes::Bytes;
use futures::FutureExt;
use russh::server::{Auth, Session};
use russh::{ChannelId, Pty};
use tokio::sync::Mutex;
use russh::server::{Auth, Handle, Session};
use russh::{ChannelId, Pty, Sig};
use russh_keys::key::PublicKey;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::oneshot;
use tracing::*;
use warpgate_common::{Secret, SessionId};
use super::session::ServerSession;
use crate::common::{PtyRequest, ServerChannelId};
use crate::{DirectTCPIPParams, X11Request};
pub struct HandleWrapper(pub Handle);
impl Debug for HandleWrapper {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "HandleWrapper")
}
}
#[derive(Debug)]
pub enum ServerHandlerEvent {
Authenticated(HandleWrapper),
ChannelOpenSession(ServerChannelId, oneshot::Sender<bool>),
SubsystemRequest(ServerChannelId, String, oneshot::Sender<()>),
PtyRequest(ServerChannelId, PtyRequest, oneshot::Sender<()>),
ShellRequest(ServerChannelId, oneshot::Sender<()>),
AuthPublicKey(Secret<String>, PublicKey, oneshot::Sender<Auth>),
AuthPassword(Secret<String>, Secret<String>, oneshot::Sender<Auth>),
AuthKeyboardInteractive(
Secret<String>,
Option<Secret<String>>,
oneshot::Sender<Auth>,
),
Data(ServerChannelId, Bytes, oneshot::Sender<()>),
ExtendedData(ServerChannelId, Bytes, u32, oneshot::Sender<()>),
ChannelClose(ServerChannelId, oneshot::Sender<()>),
ChannelEof(ServerChannelId, oneshot::Sender<()>),
WindowChangeRequest(ServerChannelId, PtyRequest, oneshot::Sender<()>),
Signal(ServerChannelId, Sig, oneshot::Sender<()>),
ExecRequest(ServerChannelId, Bytes, oneshot::Sender<()>),
ChannelOpenDirectTcpIp(ServerChannelId, DirectTCPIPParams, oneshot::Sender<bool>),
EnvRequest(ServerChannelId, String, String, oneshot::Sender<()>),
X11Request(ServerChannelId, X11Request, oneshot::Sender<()>),
Disconnect,
}
pub struct ServerHandler {
pub id: SessionId,
pub session: Arc<Mutex<ServerSession>>,
pub event_tx: UnboundedSender<ServerHandlerEvent>,
}
#[derive(thiserror::Error, Debug)]
pub enum ServerHandlerError {
#[error("channel disconnected")]
ChannelSend,
}
impl ServerHandler {
fn send_event(&self, event: ServerHandlerEvent) -> Result<(), ServerHandlerError> {
self.event_tx
.send(event)
.map_err(|_| ServerHandlerError::ChannelSend)
}
}
impl russh::server::Handler for ServerHandler {
@ -40,16 +89,25 @@ impl russh::server::Handler for ServerHandler {
async { Ok((self, s)) }.boxed()
}
fn channel_open_session(self, channel: ChannelId, mut session: Session) -> Self::FutureBool {
fn auth_succeeded(self, session: Session) -> Self::FutureUnit {
let handle = session.handle();
async {
self.send_event(ServerHandlerEvent::Authenticated(HandleWrapper(handle)))?;
Ok((self, session))
}
.boxed()
}
fn channel_open_session(self, channel: ChannelId, session: Session) -> Self::FutureBool {
async move {
let allowed = {
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._channel_open_session(ServerChannelId(channel), &mut session)
.instrument(span)
.await?
};
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::ChannelOpenSession(
ServerChannelId(channel),
tx,
))?;
let allowed = rx.await.unwrap_or(false);
Ok((self, session, allowed))
}
.boxed()
@ -63,14 +121,15 @@ impl russh::server::Handler for ServerHandler {
) -> Self::FutureUnit {
let name = name.to_string();
async move {
{
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._channel_subsystem_request(ServerChannelId(channel), name)
.instrument(span)
.await?;
}
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::SubsystemRequest(
ServerChannelId(channel),
name,
tx,
))?;
let _ = rx.await;
Ok((self, session))
}
.boxed()
@ -93,25 +152,24 @@ impl russh::server::Handler for ServerHandler {
.take_while(|x| (x.0 as u8) > 0 && (x.0 as u8) < 160)
.map(Clone::clone)
.collect();
async move {
{
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._channel_pty_request(
ServerChannelId(channel),
PtyRequest {
term,
col_width,
row_height,
pix_width,
pix_height,
modes,
},
)
.instrument(span)
.await?;
}
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::PtyRequest(
ServerChannelId(channel),
PtyRequest {
term,
col_width,
row_height,
pix_width,
pix_height,
modes,
},
tx,
))?;
let _ = rx.await;
Ok((self, session))
}
.boxed()
@ -119,70 +177,44 @@ impl russh::server::Handler for ServerHandler {
fn shell_request(self, channel: ChannelId, session: Session) -> Self::FutureUnit {
async move {
{
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._channel_shell_request_nowait(ServerChannelId(channel))
.instrument(span)
.await?;
};
let (tx, rx) = oneshot::channel();
// let reply = {
// let mut this_session = self.session.lock().await;
// let span = this_session.make_logging_span();
// let r = this_session
// ._channel_shell_request_begin(ServerChannelId(channel))
// .instrument(span)
// .await?;
// r
// };
self.send_event(ServerHandlerEvent::ShellRequest(
ServerChannelId(channel),
tx,
))?;
// // Break in ownership to allow event handling while session is started
// reply.await?;
// {
// let mut this_session = self.session.lock().await;
// let span = this_session.make_logging_span();
// this_session
// ._channel_shell_request_finish(ServerChannelId(channel))
// .instrument(span)
// .await?;
// }
let _ = rx.await;
Ok((self, session))
}
.boxed()
}
fn auth_publickey(self, user: &str, key: &russh_keys::key::PublicKey) -> Self::FutureAuth {
let user = user.to_string();
let user = Secret::new(user.to_string());
let key = key.clone();
async move {
let result = {
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._auth_publickey(user, &key)
.instrument(span)
.await
};
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::AuthPublicKey(user, key, tx))?;
let result = rx.await.unwrap_or(Auth::UnsupportedMethod);
Ok((self, result))
}
.boxed()
}
fn auth_password(self, user: &str, password: &str) -> Self::FutureAuth {
let user = user.to_string();
let password = password.to_string();
let user = Secret::new(user.to_string());
let password = Secret::new(password.to_string());
async move {
let result = {
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._auth_password(Secret::new(user), Secret::new(password))
.instrument(span)
.await
};
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::AuthPassword(user, password, tx))?;
let result = rx.await.unwrap_or(Auth::UnsupportedMethod);
Ok((self, result))
}
.boxed()
@ -194,35 +226,35 @@ impl russh::server::Handler for ServerHandler {
_submethods: &str,
response: Option<russh::server::Response>,
) -> Self::FutureAuth {
let user = user.to_string();
let user = Secret::new(user.to_string());
let response = response
.and_then(|mut r| r.next())
.and_then(|b| String::from_utf8(b.to_vec()).ok());
.and_then(|b| String::from_utf8(b.to_vec()).ok())
.map(Secret::new);
async move {
let result = {
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._auth_keyboard_interactive(Secret::new(user), response.map(Secret::new))
.instrument(span)
.await
};
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::AuthKeyboardInteractive(
user, response, tx,
))?;
let result = rx.await.unwrap_or(Auth::UnsupportedMethod);
Ok((self, result))
}
.boxed()
}
fn data(self, channel: ChannelId, data: &[u8], session: Session) -> Self::FutureUnit {
let data = BytesMut::from(data).freeze();
let channel = ServerChannelId(channel);
let data = Bytes::from(data.to_vec());
async move {
{
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._data(ServerChannelId(channel), data)
.instrument(span)
.await?;
}
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::Data(channel, data, tx))?;
let _ = rx.await;
Ok((self, session))
}
.boxed()
@ -235,31 +267,24 @@ impl russh::server::Handler for ServerHandler {
data: &[u8],
session: Session,
) -> Self::FutureUnit {
let data = BytesMut::from(data);
let channel = ServerChannelId(channel);
let data = Bytes::from(data.to_vec());
async move {
{
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._extended_data(ServerChannelId(channel), code, data)
.instrument(span)
.await?;
}
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::ExtendedData(channel, data, code, tx))?;
let _ = rx.await;
Ok((self, session))
}
.boxed()
}
fn channel_close(self, channel: ChannelId, session: Session) -> Self::FutureUnit {
let channel = ServerChannelId(channel);
async move {
{
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._channel_close(ServerChannelId(channel))
.instrument(span)
.await?;
}
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::ChannelClose(channel, tx))?;
let _ = rx.await;
Ok((self, session))
}
.boxed()
@ -275,39 +300,35 @@ impl russh::server::Handler for ServerHandler {
session: Session,
) -> Self::FutureUnit {
async move {
{
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._window_change_request(
ServerChannelId(channel),
PtyRequest {
term: "".to_string(),
col_width,
row_height,
pix_width,
pix_height,
modes: vec![],
},
)
.instrument(span)
.await?;
}
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::WindowChangeRequest(
ServerChannelId(channel),
PtyRequest {
term: "".to_string(),
col_width,
row_height,
pix_width,
pix_height,
modes: vec![],
},
tx,
))?;
let _ = rx.await;
Ok((self, session))
}
.boxed()
}
fn channel_eof(self, channel: ChannelId, session: Session) -> Self::FutureUnit {
let channel = ServerChannelId(channel);
async move {
{
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._channel_eof(ServerChannelId(channel))
.instrument(span)
.await?;
}
let (tx, rx) = oneshot::channel();
self.event_tx
.send(ServerHandlerEvent::ChannelEof(channel, tx))
.map_err(|_| ServerHandlerError::ChannelSend)?;
let _ = rx.await;
Ok((self, session))
}
.boxed()
@ -320,43 +341,28 @@ impl russh::server::Handler for ServerHandler {
session: Session,
) -> Self::FutureUnit {
async move {
{
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._channel_signal(ServerChannelId(channel), signal_name)
.instrument(span)
.await?;
}
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::Signal(
ServerChannelId(channel),
signal_name,
tx,
))?;
let _ = rx.await;
Ok((self, session))
}
.boxed()
}
fn exec_request(self, channel: ChannelId, data: &[u8], session: Session) -> Self::FutureUnit {
let data = BytesMut::from(data);
let data = Bytes::from(data.to_vec());
async move {
let reply = {
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._channel_exec_request_begin(ServerChannelId(channel), data.freeze())
.instrument(span)
.await?
};
// Break in ownership to allow event handling while session is started
reply.await?;
{
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._channel_exec_request_finish(ServerChannelId(channel))
.instrument(span)
.await?
};
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::ExecRequest(
ServerChannelId(channel),
data,
tx,
))?;
let _ = rx.await;
Ok((self, session))
}
.boxed()
@ -372,14 +378,14 @@ impl russh::server::Handler for ServerHandler {
let variable_name = variable_name.to_string();
let variable_value = variable_value.to_string();
async move {
{
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._channel_env_request(ServerChannelId(channel), variable_name, variable_value)
.instrument(span)
.await?
};
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::EnvRequest(
ServerChannelId(channel),
variable_name,
variable_value,
tx,
))?;
let _ = rx.await;
Ok((self, session))
}
.boxed()
@ -392,28 +398,23 @@ impl russh::server::Handler for ServerHandler {
port_to_connect: u32,
originator_address: &str,
originator_port: u32,
mut session: Session,
session: Session,
) -> Self::FutureBool {
let host_to_connect = host_to_connect.to_string();
let originator_address = originator_address.to_string();
async move {
let allowed = {
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._channel_open_direct_tcpip(
ServerChannelId(channel),
DirectTCPIPParams {
host_to_connect,
port_to_connect,
originator_address,
originator_port,
},
&mut session,
)
.instrument(span)
.await?
};
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::ChannelOpenDirectTcpIp(
ServerChannelId(channel),
DirectTCPIPParams {
host_to_connect,
port_to_connect,
originator_address,
originator_port,
},
tx,
))?;
let allowed = rx.await.unwrap_or(false);
Ok((self, session, allowed))
}
.boxed()
@ -431,22 +432,18 @@ impl russh::server::Handler for ServerHandler {
let x11_auth_protocol = x11_auth_protocol.to_string();
let x11_auth_cookie = x11_auth_cookie.to_string();
async move {
{
let mut this_session = self.session.lock().await;
let span = this_session.make_logging_span();
this_session
._channel_x11_request(
ServerChannelId(channel),
X11Request {
single_conection,
x11_auth_protocol,
x11_auth_cookie,
x11_screen_number,
},
)
.instrument(span)
.await?;
}
let (tx, rx) = oneshot::channel();
self.send_event(ServerHandlerEvent::X11Request(
ServerChannelId(channel),
X11Request {
single_conection,
x11_auth_protocol,
x11_auth_cookie,
x11_screen_number,
},
tx,
))?;
let _ = rx.await;
Ok((self, session))
}
.boxed()
@ -470,12 +467,13 @@ impl russh::server::Handler for ServerHandler {
impl Drop for ServerHandler {
fn drop(&mut self) {
debug!("Dropped");
let client = self.session.clone();
tokio::task::Builder::new()
.name(&format!("SSH {} cleanup", self.id))
.spawn(async move {
client.lock().await._disconnect().await;
});
let _ = self.event_tx.send(ServerHandlerEvent::Disconnect);
// let client = self.session.clone();
// tokio::task::Builder::new()
// .name(&format!("SSH {} cleanup", self.id))
// .spawn(async move {
// client.lock().await._disconnect().await;
// });
}
}

View file

@ -2,7 +2,7 @@ use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use ansi_term::Colour;
use bytes::{Bytes, BytesMut};
use bytes::Bytes;
use tokio::sync::{broadcast, mpsc};
pub const ERASE_PROGRESS_SPINNER: &str = "\r \r";
@ -39,7 +39,7 @@ impl ServiceOutput {
#[allow(clippy::indexing_slicing)]
let tick = ticks[tick_index];
let badge = Colour::Black.on(Colour::Blue).paint(format!(" {} Warpgate connecting ", tick)).to_string();
let _ = output_tx.send(BytesMut::from([&ERASE_PROGRESS_SPINNER_BUF[..], badge.as_bytes()].concat().as_slice()).freeze());
let _ = output_tx.send(Bytes::from([&ERASE_PROGRESS_SPINNER_BUF[..], badge.as_bytes()].concat()));
}
}
}
@ -70,7 +70,7 @@ impl ServiceOutput {
self.output_tx.subscribe()
}
fn emit_output(&mut self, output: Bytes) {
pub fn emit_output(&mut self, output: Bytes) {
let _ = self.output_tx.send(output);
}
}

View file

@ -10,9 +10,8 @@ use std::task::Poll;
use ansi_term::Colour;
use anyhow::{Context, Result};
use bimap::BiMap;
use bytes::{Bytes, BytesMut};
use bytes::Bytes;
use futures::Future;
use russh::server::Session;
use russh::{CryptoVec, MethodSet, Sig};
use russh_keys::key::PublicKey;
use russh_keys::PublicKeyBase64;
@ -32,6 +31,8 @@ use warpgate_core::recordings::{
};
use warpgate_core::{authorize_ticket, consume_ticket, Services, WarpgateServerHandle};
use super::channel_writer::ChannelWriter;
use super::russh_handler::ServerHandlerEvent;
use super::service_output::ServiceOutput;
use super::session_handle::SessionHandleCommand;
use crate::compat::ContextExt;
@ -52,6 +53,7 @@ enum TargetSelection {
#[derive(Debug)]
enum Event {
Command(SessionHandleCommand),
ServerHandler(ServerHandlerEvent),
ConsoleInput(Bytes),
ServiceOutput(Bytes),
Client(RCEvent),
@ -84,6 +86,7 @@ pub struct ServerSession {
hub: EventHub<Event>,
event_sender: EventSender<Event>,
service_output: ServiceOutput,
channel_writer: ChannelWriter,
auth_state: Option<Arc<Mutex<AuthState>>>,
keyboard_interactive_state: KeyboardInteractiveState,
}
@ -104,7 +107,8 @@ impl ServerSession {
services: &Services,
server_handle: Arc<Mutex<WarpgateServerHandle>>,
mut session_handle_rx: UnboundedReceiver<SessionHandleCommand>,
) -> Result<Arc<Mutex<Self>>> {
mut handler_event_rx: UnboundedReceiver<ServerHandlerEvent>,
) -> Result<impl Future<Output = Result<()>>> {
let id = server_handle.lock().await.id();
let _span = info_span!("SSH", session=%id);
@ -113,9 +117,11 @@ impl ServerSession {
let mut rc_handles = RemoteClient::create(id, services.clone());
let (hub, event_sender) = EventHub::setup();
let mut event_sub = hub.subscribe(|_| true).await;
let mut event_sub = hub
.subscribe(|e| !matches!(e, Event::ConsoleInput(_)))
.await;
let this = Self {
let mut this = Self {
id,
username: None,
session_handle: None,
@ -136,6 +142,7 @@ impl ServerSession {
hub,
event_sender: event_sender.clone(),
service_output: ServiceOutput::new(),
channel_writer: ChannelWriter::new(),
auth_state: None,
keyboard_interactive_state: KeyboardInteractiveState::None,
};
@ -160,8 +167,6 @@ impl ServerSession {
}
});
let this = Arc::new(Mutex::new(this));
let name = format!("SSH {} session control", id);
tokio::task::Builder::new().name(&name).spawn({
let sender = event_sender.clone();
@ -186,51 +191,29 @@ impl ServerSession {
}
});
let name = format!("SSH {} events", id);
let name = format!("SSH {} server handler events", id);
tokio::task::Builder::new().name(&name).spawn({
let this = Arc::downgrade(&this);
let sender = event_sender.clone();
async move {
loop {
match event_sub.recv().await {
Some(Event::Client(RCEvent::Done)) => break,
Some(Event::Client(e)) => {
debug!(event=?e, "Event");
let Some(this) = this.upgrade() else {
break;
};
let this = &mut this.lock().await;
if let Err(err) = this.handle_remote_event(e).await {
error!("Event handler error: {:?}", err);
break;
}
}
Some(Event::Command(command)) => {
debug!(?command, "Session control");
let Some(this) = this.upgrade() else {
break;
};
let this = &mut this.lock().await;
if let Err(err) = this.handle_session_control(command).await {
error!("Event handler error: {:?}", err);
break;
}
}
Some(Event::ServiceOutput(data)) => {
let Some(this) = this.upgrade() else {
break;
};
let this = &mut this.lock().await;
let _ = this.emit_pty_output(&data).await;
}
Some(Event::ConsoleInput(_)) => (),
None => break,
while let Some(e) = handler_event_rx.recv().await {
if sender.send_once(Event::ServerHandler(e)).await.is_err() {
break;
}
}
debug!("No more events");
}
});
Ok(this)
Ok(async move {
while let Some(event) = event_sub.recv().await {
match event {
Event::Client(RCEvent::Done) => break,
Event::ServerHandler(ServerHandlerEvent::Disconnect) => break,
event => this.handle_event(event).await?,
}
}
debug!("No more events");
Ok::<_, anyhow::Error>(())
})
}
async fn get_auth_state(&mut self, username: &str) -> Result<Arc<Mutex<AuthState>>> {
@ -276,16 +259,14 @@ impl ServerSession {
pub async fn emit_service_message(&mut self, msg: &str) -> Result<()> {
debug!("Service message: {}", msg);
self.emit_pty_output(
format!(
"{}{} {}\r\n",
ERASE_PROGRESS_SPINNER,
Colour::Black.on(Colour::White).paint(" Warpgate "),
msg.replace('\n', "\r\n"),
)
.as_bytes(),
)
.await
self.service_output.emit_output(Bytes::from(format!(
"{}{} {}\r\n",
ERASE_PROGRESS_SPINNER,
Colour::Black.on(Colour::White).paint(" Warpgate "),
msg.replace('\n', "\r\n"),
)));
Ok(())
}
pub async fn emit_pty_output(&mut self, data: &[u8]) -> Result<()> {
@ -293,9 +274,8 @@ impl ServerSession {
for channel in channels {
let channel = self.map_channel_reverse(&channel)?;
if let Some(session) = self.session_handle.clone() {
// .data() will hang and deadlock us if the mpsc capacity is exhausted
let data = CryptoVec::from_slice(data);
tokio::spawn(async move { session.data(channel.0, data).await });
self.channel_writer
.write(session, channel.0, CryptoVec::from_slice(data));
}
}
Ok(())
@ -336,6 +316,181 @@ impl ServerSession {
Ok(())
}
async fn handle_event(&mut self, event: Event) -> Result<()> {
match event {
Event::Client(e) => {
debug!(event=?e, "Event");
let span = self.make_logging_span();
if let Err(err) = self.handle_remote_event(e).instrument(span).await {
error!("Event handler error: {:?}", err);
// break;
}
}
Event::ServerHandler(e) => {
let span = self.make_logging_span();
if let Err(err) = self.handle_server_handler_event(e).instrument(span).await {
error!("Event handler error: {:?}", err);
// break;
}
}
Event::Command(command) => {
debug!(?command, "Session control");
if let Err(err) = self.handle_session_control(command).await {
error!("Event handler error: {:?}", err);
// break;
}
}
Event::ServiceOutput(data) => {
let _ = self.emit_pty_output(&data).await;
}
Event::ConsoleInput(_) => (),
}
Ok(())
}
async fn handle_server_handler_event(&mut self, event: ServerHandlerEvent) -> Result<()> {
match event {
ServerHandlerEvent::Authenticated(handle) => {
self.session_handle = Some(handle.0);
}
ServerHandlerEvent::ChannelOpenSession(server_channel_id, reply) => {
let channel = Uuid::new_v4();
self.channel_map.insert(server_channel_id, channel);
info!(%channel, "Opening session channel");
return match self
.send_command_and_wait(RCCommand::Channel(channel, ChannelOperation::OpenShell))
.await
{
Ok(()) => {
self.all_channels.push(channel);
let _ = reply.send(true);
Ok(())
}
Err(SshClientError::ChannelFailure) => {
let _ = reply.send(false);
Ok(())
}
Err(x) => Err(x.into()),
};
}
ServerHandlerEvent::SubsystemRequest(server_channel_id, name, _) => {
self._channel_subsystem_request(server_channel_id, name)
.await?;
}
ServerHandlerEvent::PtyRequest(server_channel_id, request, _) => {
let channel_id = self.map_channel(&server_channel_id)?;
self.channel_pty_size_map
.insert(channel_id, request.clone());
if let Some(recorder) = self.channel_recorders.get_mut(&channel_id) {
if let Err(error) = recorder
.write_pty_resize(request.col_width, request.row_height)
.await
{
error!(%channel_id, ?error, "Failed to record terminal data");
self.channel_recorders.remove(&channel_id);
}
}
self.send_command_and_wait(RCCommand::Channel(
channel_id,
ChannelOperation::RequestPty(request),
))
.await?;
let _ = self
.session_handle
.as_mut()
.context("Invalid session state")?
.channel_success(server_channel_id.0)
.await;
self.pty_channels.push(channel_id);
}
ServerHandlerEvent::ShellRequest(server_channel_id, _) => {
let channel_id = self.map_channel(&server_channel_id)?;
let _ = self.maybe_connect_remote().await;
let _ = self.send_command(RCCommand::Channel(
channel_id,
ChannelOperation::RequestShell,
));
self.start_terminal_recording(
channel_id,
format!("shell-channel-{}", server_channel_id.0),
)
.await;
info!(%channel_id, "Opening shell");
let _ = self
.session_handle
.as_mut()
.context("Invalid session state")?
.channel_success(server_channel_id.0)
.await;
}
ServerHandlerEvent::AuthPublicKey(username, key, reply) => {
let _ = reply.send(self._auth_publickey(username, key).await);
}
ServerHandlerEvent::AuthPassword(username, password, reply) => {
let _ = reply.send(self._auth_password(username, password).await);
}
ServerHandlerEvent::AuthKeyboardInteractive(username, response, reply) => {
let _ = reply.send(self._auth_keyboard_interactive(username, response).await);
}
ServerHandlerEvent::Data(channel, data, _) => {
self._data(channel, data).await?;
}
ServerHandlerEvent::ExtendedData(channel, data, code, _) => {
self._extended_data(channel, code, data).await?;
}
ServerHandlerEvent::ChannelClose(channel, _) => {
self._channel_close(channel).await?;
}
ServerHandlerEvent::ChannelEof(channel, _) => {
self._channel_eof(channel).await?;
}
ServerHandlerEvent::WindowChangeRequest(channel, request, _) => {
self._window_change_request(channel, request).await?;
}
ServerHandlerEvent::Signal(channel, signal, _) => {
self._channel_signal(channel, signal).await?;
}
ServerHandlerEvent::ExecRequest(channel, data, _) => {
self._channel_exec_request(channel, data).await?;
}
ServerHandlerEvent::ChannelOpenDirectTcpIp(channel, params, reply) => {
let _ = reply.send(self._channel_open_direct_tcpip(channel, params).await?);
}
ServerHandlerEvent::EnvRequest(channel, name, value, _) => {
self._channel_env_request(channel, name, value).await?;
}
ServerHandlerEvent::X11Request(channel, request, _) => {
self._channel_x11_request(channel, request).await?;
}
ServerHandlerEvent::Disconnect => (),
}
Ok(())
}
pub async fn handle_session_control(&mut self, command: SessionHandleCommand) -> Result<()> {
match command {
SessionHandleCommand::Close => {
@ -355,17 +510,13 @@ impl ServerSession {
match &self.rc_state {
RCState::Connected => {
self.service_output.hide_progress().await;
self.emit_pty_output(
format!(
"{}{}\r\n",
ERASE_PROGRESS_SPINNER,
Colour::Black
.on(Colour::Green)
.paint(" ✓ Warpgate connected ")
)
.as_bytes(),
)
.await?;
self.service_output.emit_output(Bytes::from(format!(
"{}{}\r\n",
ERASE_PROGRESS_SPINNER,
Colour::Black
.on(Colour::Green)
.paint(" ✓ Warpgate connected ")
)));
}
RCState::Disconnected => {
self.service_output.hide_progress().await;
@ -406,16 +557,12 @@ impl ServerSession {
.await?;
}
error => {
self.emit_pty_output(
format!(
"{}{} {}\r\n",
ERASE_PROGRESS_SPINNER,
Colour::Black.on(Colour::Red).paint(" Connection failed "),
error
)
.as_bytes(),
)
.await?;
self.service_output.emit_output(Bytes::from(format!(
"{}{} {}\r\n",
ERASE_PROGRESS_SPINNER,
Colour::Black.on(Colour::Red).paint(" Connection failed "),
error
)));
}
}
}
@ -443,14 +590,11 @@ impl ServerSession {
}
let server_channel_id = self.map_channel_reverse(&channel)?;
self.maybe_with_session(|handle| async move {
handle
if let Some(session) = self.session_handle.as_mut() {
let _ = session
.data(server_channel_id.0, CryptoVec::from_slice(&data))
.await
.map_err(|_| ())
.context("failed to send data")
})
.await?;
.await;
}
}
RCEvent::Success(channel) => {
let server_channel_id = self.map_channel_reverse(&channel)?;
@ -646,41 +790,16 @@ impl ServerSession {
Ok(None)
}
pub async fn _channel_open_session(
&mut self,
server_channel_id: ServerChannelId,
session: &mut Session,
) -> Result<bool> {
let channel = Uuid::new_v4();
self.channel_map.insert(server_channel_id, channel);
info!(%channel, "Opening session channel");
self.session_handle = Some(session.handle());
match self
.send_command_and_wait(RCCommand::Channel(channel, ChannelOperation::OpenShell))
.await
{
Ok(()) => {
self.all_channels.push(channel);
Ok(true)
}
Err(SshClientError::ChannelFailure) => Ok(false),
Err(x) => Err(x.into()),
}
}
pub async fn _channel_open_direct_tcpip(
async fn _channel_open_direct_tcpip(
&mut self,
channel: ServerChannelId,
params: DirectTCPIPParams,
session: &mut Session,
) -> Result<bool> {
let uuid = Uuid::new_v4();
self.channel_map.insert(channel, uuid);
info!(%channel, "Opening direct TCP/IP channel from {}:{} to {}:{}", params.originator_address, params.originator_port, params.host_to_connect, params.port_to_connect);
self.session_handle = Some(session.handle());
match self
.send_command_and_wait(RCCommand::Channel(
uuid,
@ -715,39 +834,7 @@ impl ServerSession {
}
}
pub async fn _channel_pty_request(
&mut self,
server_channel_id: ServerChannelId,
request: PtyRequest,
) -> Result<()> {
let channel_id = self.map_channel(&server_channel_id)?;
self.channel_pty_size_map
.insert(channel_id, request.clone());
if let Some(recorder) = self.channel_recorders.get_mut(&channel_id) {
if let Err(error) = recorder
.write_pty_resize(request.col_width, request.row_height)
.await
{
error!(%channel_id, ?error, "Failed to record terminal data");
self.channel_recorders.remove(&channel_id);
}
}
self.send_command_and_wait(RCCommand::Channel(
channel_id,
ChannelOperation::RequestPty(request),
))
.await?;
let _ = self
.session_handle
.as_mut()
.context("Invalid session state")?
.channel_success(server_channel_id.0)
.await;
self.pty_channels.push(channel_id);
Ok(())
}
pub async fn _window_change_request(
async fn _window_change_request(
&mut self,
server_channel_id: ServerChannelId,
request: PtyRequest,
@ -772,11 +859,11 @@ impl ServerSession {
Ok(())
}
pub async fn _channel_exec_request_begin(
async fn _channel_exec_request(
&mut self,
server_channel_id: ServerChannelId,
data: Bytes,
) -> Result<PendingCommand> {
) -> Result<()> {
let channel_id = self.map_channel(&server_channel_id)?;
match std::str::from_utf8(&data) {
Err(e) => {
@ -786,19 +873,13 @@ impl ServerSession {
Ok::<&str, _>(command) => {
debug!(channel=%channel_id, %command, "Requested exec");
let _ = self.maybe_connect_remote().await;
Ok(self.send_command_and_wait(RCCommand::Channel(
let _ = self.send_command(RCCommand::Channel(
channel_id,
ChannelOperation::RequestExec(command.to_string()),
)))
));
}
}
}
pub async fn _channel_exec_request_finish(
&mut self,
server_channel_id: ServerChannelId,
) -> Result<()> {
let channel_id = self.map_channel(&server_channel_id)?;
self.start_terminal_recording(channel_id, format!("exec-channel-{}", server_channel_id.0))
.await;
Ok(())
@ -832,7 +913,7 @@ impl ServerSession {
}
}
pub async fn _channel_x11_request(
async fn _channel_x11_request(
&mut self,
server_channel_id: ServerChannelId,
request: X11Request,
@ -848,7 +929,7 @@ impl ServerSession {
Ok(())
}
pub async fn _channel_env_request(
async fn _channel_env_request(
&mut self,
server_channel_id: ServerChannelId,
name: String,
@ -962,7 +1043,7 @@ impl ServerSession {
Ok(())
}
pub async fn _data(&mut self, server_channel_id: ServerChannelId, data: Bytes) -> Result<()> {
async fn _data(&mut self, server_channel_id: ServerChannelId, data: Bytes) -> Result<()> {
let channel_id = self.map_channel(&server_channel_id)?;
debug!(channel=%server_channel_id.0, ?data, "Data");
if self.rc_state == RCState::Connecting && data.first() == Some(&3) {
@ -999,30 +1080,27 @@ impl ServerSession {
Ok(())
}
pub async fn _extended_data(
async fn _extended_data(
&mut self,
server_channel_id: ServerChannelId,
code: u32,
data: BytesMut,
data: Bytes,
) -> Result<()> {
let channel_id = self.map_channel(&server_channel_id)?;
debug!(channel=%server_channel_id.0, ?data, "Data");
let _ = self.send_command_and_wait(RCCommand::Channel(
channel_id,
ChannelOperation::ExtendedData {
ext: code,
data: data.freeze(),
},
ChannelOperation::ExtendedData { ext: code, data },
));
Ok(())
}
pub async fn _auth_publickey(
async fn _auth_publickey(
&mut self,
ssh_username: String,
key: &PublicKey,
ssh_username: Secret<String>,
key: PublicKey,
) -> russh::server::Auth {
let selector: AuthSelector = (&ssh_username).into();
let selector: AuthSelector = ssh_username.expose_secret().into();
info!(
"Public key auth as {:?} with key {}",
@ -1056,13 +1134,13 @@ impl ServerSession {
}
}
pub async fn _auth_password(
async fn _auth_password(
&mut self,
ssh_username: Secret<String>,
password: Secret<String>,
) -> russh::server::Auth {
let selector: AuthSelector = ssh_username.expose_secret().into();
info!("Password key auth as {:?}", selector);
info!("Password auth as {:?}", selector);
match self
.try_auth(&selector, Some(AuthCredential::Password(password)))
@ -1084,7 +1162,7 @@ impl ServerSession {
}
}
pub async fn _auth_keyboard_interactive(
async fn _auth_keyboard_interactive(
&mut self,
ssh_username: Secret<String>,
response: Option<Secret<String>>,
@ -1311,7 +1389,7 @@ impl ServerSession {
Ok(())
}
pub async fn _channel_close(&mut self, server_channel_id: ServerChannelId) -> Result<()> {
async fn _channel_close(&mut self, server_channel_id: ServerChannelId) -> Result<()> {
let channel_id = self.map_channel(&server_channel_id)?;
debug!(channel=%channel_id, "Closing channel");
self.send_command_and_wait(RCCommand::Channel(channel_id, ChannelOperation::Close))
@ -1319,7 +1397,7 @@ impl ServerSession {
Ok(())
}
pub async fn _channel_eof(&mut self, server_channel_id: ServerChannelId) -> Result<()> {
async fn _channel_eof(&mut self, server_channel_id: ServerChannelId) -> Result<()> {
let channel_id = self.map_channel(&server_channel_id)?;
debug!(channel=%channel_id, "EOF");
self.send_command_and_wait(RCCommand::Channel(channel_id, ChannelOperation::Eof))