use std::borrow::Cow; use std::collections::hash_map::Entry::Vacant; use std::collections::{HashMap, HashSet}; use std::net::{Ipv4Addr, SocketAddr}; use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; use std::task::Poll; use ansi_term::Colour; use anyhow::{Context, Result}; use bimap::BiMap; use bytes::Bytes; use futures::{Future, FutureExt}; use russh::{CryptoVec, MethodSet, Sig}; use russh_keys::key::PublicKey; use russh_keys::PublicKeyBase64; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::sync::{broadcast, oneshot, Mutex}; use tracing::*; use uuid::Uuid; use warpgate_common::auth::{AuthCredential, AuthResult, AuthSelector, AuthState, CredentialKind}; use warpgate_common::eventhub::{EventHub, EventSender, EventSubscription}; use warpgate_common::{ Secret, SessionId, SshHostKeyVerificationMode, Target, TargetOptions, TargetSSHOptions, WarpgateError, }; use warpgate_core::recordings::{ self, ConnectionRecorder, TerminalRecorder, TerminalRecordingStreamId, TrafficConnectionParams, TrafficRecorder, }; use warpgate_core::{authorize_ticket, consume_ticket, Services, WarpgateServerHandle}; use super::channel_writer::ChannelWriter; use super::russh_handler::ServerHandlerEvent; use super::service_output::ServiceOutput; use super::session_handle::SessionHandleCommand; use crate::compat::ContextExt; use crate::server::service_output::ERASE_PROGRESS_SPINNER; use crate::{ ChannelOperation, ConnectionError, DirectTCPIPParams, PtyRequest, RCCommand, RCCommandReply, RCEvent, RCState, RemoteClient, ServerChannelId, SshClientError, X11Request, }; #[derive(Clone)] #[allow(clippy::large_enum_variant)] enum TargetSelection { None, NotFound(String), Found(Target, TargetSSHOptions), } #[derive(Debug)] enum Event { Command(SessionHandleCommand), ServerHandler(ServerHandlerEvent), ConsoleInput(Bytes), ServiceOutput(Bytes), Client(RCEvent), } enum KeyboardInteractiveState { None, OtpRequested, WebAuthRequested(broadcast::Receiver), } pub struct 
ServerSession { pub id: SessionId, username: Option, session_handle: Option, pty_channels: Vec, all_channels: Vec, channel_recorders: HashMap, channel_map: BiMap, channel_pty_size_map: HashMap, rc_tx: UnboundedSender<(RCCommand, Option)>, rc_abort_tx: UnboundedSender<()>, rc_state: RCState, remote_address: SocketAddr, services: Services, server_handle: Arc>, target: TargetSelection, traffic_recorders: HashMap<(String, u32), TrafficRecorder>, traffic_connection_recorders: HashMap, hub: EventHub, event_sender: EventSender, main_event_subscription: EventSubscription, service_output: ServiceOutput, channel_writer: ChannelWriter, auth_state: Option>>, keyboard_interactive_state: KeyboardInteractiveState, } fn session_debug_tag(id: &SessionId, remote_address: &SocketAddr) -> String { format!("[{} - {}]", id, remote_address) } impl std::fmt::Debug for ServerSession { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", session_debug_tag(&self.id, &self.remote_address)) } } impl ServerSession { pub async fn start( remote_address: SocketAddr, services: &Services, server_handle: Arc>, mut session_handle_rx: UnboundedReceiver, mut handler_event_rx: UnboundedReceiver, ) -> Result>> { let id = server_handle.lock().await.id(); let _span = info_span!("SSH", session=%id); let _enter = _span.enter(); let mut rc_handles = RemoteClient::create(id, services.clone()); let (hub, event_sender) = EventHub::setup(); let main_event_subscription = hub .subscribe(|e| !matches!(e, Event::ConsoleInput(_))) .await; let mut this = Self { id, username: None, session_handle: None, pty_channels: vec![], all_channels: vec![], channel_recorders: HashMap::new(), channel_map: BiMap::new(), channel_pty_size_map: HashMap::new(), rc_tx: rc_handles.command_tx.clone(), rc_abort_tx: rc_handles.abort_tx, rc_state: RCState::NotInitialized, remote_address, services: services.clone(), server_handle, target: TargetSelection::None, traffic_recorders: HashMap::new(), 
traffic_connection_recorders: HashMap::new(), hub, event_sender: event_sender.clone(), main_event_subscription, service_output: ServiceOutput::new(), channel_writer: ChannelWriter::new(), auth_state: None, keyboard_interactive_state: KeyboardInteractiveState::None, }; let mut so_rx = this.service_output.subscribe(); let so_sender = event_sender.clone(); tokio::spawn(async move { loop { match so_rx.recv().await { Ok(data) => { if so_sender .send_once(Event::ServiceOutput(data)) .await .is_err() { break; } } Err(broadcast::error::RecvError::Closed) => break, Err(_) => (), } } }); let name = format!("SSH {} session control", id); tokio::task::Builder::new().name(&name).spawn({ let sender = event_sender.clone(); async move { while let Some(command) = session_handle_rx.recv().await { if sender.send_once(Event::Command(command)).await.is_err() { break; } } } }); let name = format!("SSH {} client events", id); tokio::task::Builder::new().name(&name).spawn({ let sender = event_sender.clone(); async move { while let Some(e) = rc_handles.event_rx.recv().await { if sender.send_once(Event::Client(e)).await.is_err() { break; } } } }); let name = format!("SSH {} server handler events", id); tokio::task::Builder::new().name(&name).spawn({ let sender = event_sender.clone(); async move { while let Some(e) = handler_event_rx.recv().await { if sender.send_once(Event::ServerHandler(e)).await.is_err() { break; } } } }); Ok(async move { while let Some(event) = this.get_next_event().await { this.handle_event(event).await?; } debug!("No more events"); Ok::<_, anyhow::Error>(()) }) } async fn get_next_event(&mut self) -> Option { self.main_event_subscription.recv().await } async fn get_auth_state(&mut self, username: &str) -> Result>> { #[allow(clippy::unwrap_used)] if self.auth_state.is_none() || self.auth_state.as_ref().unwrap().lock().await.username() != username { let state = self .services .auth_state_store .lock() .await .create(username, crate::PROTOCOL_NAME) .await? 
.1; self.auth_state = Some(state); } #[allow(clippy::unwrap_used)] Ok(self.auth_state.as_ref().map(Clone::clone).unwrap()) } pub fn make_logging_span(&self) -> tracing::Span { match self.username { Some(ref username) => info_span!("SSH", session=%self.id, session_username=%username), None => info_span!("SSH", session=%self.id), } } fn map_channel(&self, ch: &ServerChannelId) -> Result { self.channel_map .get_by_left(ch) .cloned() .ok_or_else(|| anyhow::anyhow!("Channel not known")) } fn map_channel_reverse(&self, ch: &Uuid) -> Result { self.channel_map .get_by_right(ch) .cloned() .ok_or_else(|| anyhow::anyhow!("Channel not known")) } pub async fn emit_service_message(&mut self, msg: &str) -> Result<()> { debug!("Service message: {}", msg); self.service_output.emit_output(Bytes::from(format!( "{}{} {}\r\n", ERASE_PROGRESS_SPINNER, Colour::Black.on(Colour::White).paint(" Warpgate "), msg.replace('\n', "\r\n"), ))); Ok(()) } pub async fn emit_pty_output(&mut self, data: &[u8]) -> Result<()> { let channels = self.pty_channels.clone(); for channel in channels { let channel = self.map_channel_reverse(&channel)?; if let Some(session) = self.session_handle.clone() { self.channel_writer .write(session, channel.0, CryptoVec::from_slice(data)); } } Ok(()) } pub async fn maybe_connect_remote(&mut self) -> Result<()> { match self.target.clone() { TargetSelection::None => { anyhow::bail!("Invalid session state (target not set)") } TargetSelection::NotFound(name) => { self.emit_service_message(&format!("Selected target not found: {name}")) .await?; self.disconnect_server().await; anyhow::bail!("Target not found: {}", name); } TargetSelection::Found(target, ssh_options) => { if self.rc_state == RCState::NotInitialized { self.connect_remote(target, ssh_options).await?; } } } Ok(()) } async fn connect_remote( &mut self, target: Target, ssh_options: TargetSSHOptions, ) -> Result<()> { self.rc_state = RCState::Connecting; self.send_command(RCCommand::Connect(ssh_options)) .map_err(|_| 
anyhow::anyhow!("cannot send command"))?; self.service_output.show_progress(); self.emit_service_message(&format!("Selected target: {}", target.name)) .await?; Ok(()) } fn handle_event<'a>( &'a mut self, event: Event, ) -> Pin> + Send + 'a>> { async move { match event { Event::Client(RCEvent::Done) => Err(WarpgateError::SessionEnd)?, Event::ServerHandler(ServerHandlerEvent::Disconnect) => { Err(WarpgateError::SessionEnd)? } Event::Client(e) => { debug!(event=?e, "Event"); let span = self.make_logging_span(); if let Err(err) = self.handle_remote_event(e).instrument(span).await { error!("Event handler error: {:?}", err); // break; } } Event::ServerHandler(e) => { let span = self.make_logging_span(); if let Err(err) = self.handle_server_handler_event(e).instrument(span).await { error!("Event handler error: {:?}", err); // break; } } Event::Command(command) => { debug!(?command, "Session control"); if let Err(err) = self.handle_session_control(command).await { error!("Event handler error: {:?}", err); // break; } } Event::ServiceOutput(data) => { let _ = self.emit_pty_output(&data).await; } Event::ConsoleInput(_) => (), } Ok(()) } .boxed() } async fn handle_server_handler_event(&mut self, event: ServerHandlerEvent) -> Result<()> { match event { ServerHandlerEvent::Authenticated(handle) => { self.session_handle = Some(handle.0); } ServerHandlerEvent::ChannelOpenSession(server_channel_id, reply) => { let channel = Uuid::new_v4(); self.channel_map.insert(server_channel_id, channel); info!(%channel, "Opening session channel"); return match self .send_command_and_wait(RCCommand::Channel(channel, ChannelOperation::OpenShell)) .await { Ok(()) => { self.all_channels.push(channel); let _ = reply.send(true); Ok(()) } Err(SshClientError::ChannelFailure) => { let _ = reply.send(false); Ok(()) } Err(x) => Err(x.into()), }; } ServerHandlerEvent::SubsystemRequest(server_channel_id, name, _) => { self._channel_subsystem_request(server_channel_id, name) .await?; } 
ServerHandlerEvent::PtyRequest(server_channel_id, request, _) => { let channel_id = self.map_channel(&server_channel_id)?; self.channel_pty_size_map .insert(channel_id, request.clone()); if let Some(recorder) = self.channel_recorders.get_mut(&channel_id) { if let Err(error) = recorder .write_pty_resize(request.col_width, request.row_height) .await { error!(%channel_id, ?error, "Failed to record terminal data"); self.channel_recorders.remove(&channel_id); } } self.send_command_and_wait(RCCommand::Channel( channel_id, ChannelOperation::RequestPty(request), )) .await?; let _ = self .session_handle .as_mut() .context("Invalid session state")? .channel_success(server_channel_id.0) .await; self.pty_channels.push(channel_id); } ServerHandlerEvent::ShellRequest(server_channel_id, _) => { let channel_id = self.map_channel(&server_channel_id)?; let _ = self.maybe_connect_remote().await; let _ = self.send_command(RCCommand::Channel( channel_id, ChannelOperation::RequestShell, )); self.start_terminal_recording( channel_id, format!("shell-channel-{}", server_channel_id.0), ) .await; info!(%channel_id, "Opening shell"); let _ = self .session_handle .as_mut() .context("Invalid session state")? 
.channel_success(server_channel_id.0) .await; } ServerHandlerEvent::AuthPublicKey(username, key, reply) => { let _ = reply.send(self._auth_publickey(username, key).await); } ServerHandlerEvent::AuthPassword(username, password, reply) => { let _ = reply.send(self._auth_password(username, password).await); } ServerHandlerEvent::AuthKeyboardInteractive(username, response, reply) => { let _ = reply.send(self._auth_keyboard_interactive(username, response).await); } ServerHandlerEvent::Data(channel, data, _) => { self._data(channel, data).await?; } ServerHandlerEvent::ExtendedData(channel, data, code, _) => { self._extended_data(channel, code, data).await?; } ServerHandlerEvent::ChannelClose(channel, _) => { self._channel_close(channel).await?; } ServerHandlerEvent::ChannelEof(channel, _) => { self._channel_eof(channel).await?; } ServerHandlerEvent::WindowChangeRequest(channel, request, _) => { self._window_change_request(channel, request).await?; } ServerHandlerEvent::Signal(channel, signal, _) => { self._channel_signal(channel, signal).await?; } ServerHandlerEvent::ExecRequest(channel, data, _) => { self._channel_exec_request(channel, data).await?; } ServerHandlerEvent::ChannelOpenDirectTcpIp(channel, params, reply) => { let _ = reply.send(self._channel_open_direct_tcpip(channel, params).await?); } ServerHandlerEvent::EnvRequest(channel, name, value, _) => { self._channel_env_request(channel, name, value).await?; } ServerHandlerEvent::X11Request(channel, request, _) => { self._channel_x11_request(channel, request).await?; } ServerHandlerEvent::TcpIpForward(address, port, reply) => { self._tcpip_forward(address, port).await?; let _ = reply.send(true); } ServerHandlerEvent::CancelTcpIpForward(address, port, reply) => { self._cancel_tcpip_forward(address, port).await?; let _ = reply.send(true); } ServerHandlerEvent::Disconnect => (), } Ok(()) } pub async fn handle_session_control(&mut self, command: SessionHandleCommand) -> Result<()> { match command { 
SessionHandleCommand::Close => { let _ = self.emit_service_message("Session closed by admin").await; info!("Session closed by admin"); self.request_disconnect().await; self.disconnect_server().await; } } Ok(()) } pub async fn handle_remote_event(&mut self, event: RCEvent) -> Result<()> { match event { RCEvent::State(state) => { self.rc_state = state; match &self.rc_state { RCState::Connected => { self.service_output.hide_progress().await; self.service_output.emit_output(Bytes::from(format!( "{}{}\r\n", ERASE_PROGRESS_SPINNER, Colour::Black .on(Colour::Green) .paint(" ✓ Warpgate connected ") ))); } RCState::Disconnected => { self.service_output.hide_progress().await; self.disconnect_server().await; } _ => {} } } RCEvent::ConnectionError(error) => { self.service_output.hide_progress().await; match error { ConnectionError::HostKeyMismatch { received_key_type, received_key_base64, known_key_type, known_key_base64, } => { let msg = format!( concat!( "Host key doesn't match the stored one.\n", "Stored key ({}): {}\n", "Received key ({}): {}", ), known_key_type, known_key_base64, received_key_type, received_key_base64 ); self.emit_service_message(&msg).await?; self.emit_service_message( "If you know that the key is correct (e.g. 
it has been changed),", ) .await?; self.emit_service_message( "you can remove the old key in the Warpgate management UI and try again", ) .await?; } error => { self.service_output.emit_output(Bytes::from(format!( "{}{} {}\r\n", ERASE_PROGRESS_SPINNER, Colour::Black.on(Colour::Red).paint(" Connection failed "), error ))); } } } RCEvent::Error(e) => { self.service_output.hide_progress().await; let _ = self.emit_service_message(&format!("Error: {e}")).await; self.disconnect_server().await; } RCEvent::Output(channel, data) => { if let Some(recorder) = self.channel_recorders.get_mut(&channel) { if let Err(error) = recorder .write(TerminalRecordingStreamId::Output, &data) .await { error!(%channel, ?error, "Failed to record terminal data"); self.channel_recorders.remove(&channel); } } if let Some(recorder) = self.traffic_connection_recorders.get_mut(&channel) { if let Err(error) = recorder.write_rx(&data).await { error!(%channel, ?error, "Failed to record traffic data"); self.traffic_connection_recorders.remove(&channel); } } let server_channel_id = self.map_channel_reverse(&channel)?; if let Some(session) = self.session_handle.as_mut() { let _ = session .data(server_channel_id.0, CryptoVec::from_slice(&data)) .await; } } RCEvent::Success(channel) => { let server_channel_id = self.map_channel_reverse(&channel)?; self.maybe_with_session(|handle| async move { handle .channel_success(server_channel_id.0) .await .context("failed to send data") }) .await?; } RCEvent::ChannelFailure(channel) => { let server_channel_id = self.map_channel_reverse(&channel)?; self.maybe_with_session(|handle| async move { handle .channel_failure(server_channel_id.0) .await .context("failed to send data") }) .await?; } RCEvent::Close(channel) => { let server_channel_id = self.map_channel_reverse(&channel)?; let _ = self .maybe_with_session(|handle| async move { handle .close(server_channel_id.0) .await .context("failed to close ch") }) .await; } RCEvent::Eof(channel) => { let server_channel_id = 
self.map_channel_reverse(&channel)?; self.maybe_with_session(|handle| async move { handle .eof(server_channel_id.0) .await .context("failed to send eof") }) .await?; } RCEvent::ExitStatus(channel, code) => { let server_channel_id = self.map_channel_reverse(&channel)?; self.maybe_with_session(|handle| async move { handle .exit_status_request(server_channel_id.0, code) .await .context("failed to send exit status") }) .await?; } RCEvent::ExitSignal { channel, signal_name, core_dumped, error_message, lang_tag, } => { let server_channel_id = self.map_channel_reverse(&channel)?; self.maybe_with_session(|handle| async move { handle .exit_signal_request( server_channel_id.0, signal_name, core_dumped, error_message, lang_tag, ) .await .context("failed to send exit status")?; Ok(()) }) .await?; } RCEvent::Done => {} RCEvent::ExtendedData { channel, data, ext } => { if let Some(recorder) = self.channel_recorders.get_mut(&channel) { if let Err(error) = recorder .write(TerminalRecordingStreamId::Error, &data) .await { error!(%channel, ?error, "Failed to record session data"); self.channel_recorders.remove(&channel); } } let server_channel_id = self.map_channel_reverse(&channel)?; self.maybe_with_session(|handle| async move { handle .extended_data(server_channel_id.0, ext, CryptoVec::from_slice(&data)) .await .map_err(|_| ()) .context("failed to send extended data")?; Ok(()) }) .await?; } RCEvent::HostKeyReceived(key) => { self.emit_service_message(&format!( "Host key ({}): {}", key.name(), key.public_key_base64() )) .await?; } RCEvent::HostKeyUnknown(key, reply) => { self.handle_unknown_host_key(key, reply).await?; } RCEvent::ForwardedTcpIp(id, params) => { if let Some(session) = &mut self.session_handle { let server_channel = session .channel_open_forwarded_tcpip( params.connected_address, params.connected_port, params.originator_address.clone(), params.originator_port, ) .await?; self.channel_map .insert(ServerChannelId(server_channel.id()), id); self.all_channels.push(id); 
let recorder = self .traffic_recorder_for( ¶ms.originator_address, params.originator_port, "forwarded-tcpip", ) .await; if let Some(recorder) = recorder { #[allow(clippy::unwrap_used)] let mut recorder = recorder.connection(TrafficConnectionParams { dst_addr: Ipv4Addr::from_str("2.2.2.2").unwrap(), dst_port: params.connected_port as u16, src_addr: Ipv4Addr::from_str("1.1.1.1").unwrap(), src_port: params.originator_port as u16, }); if let Err(error) = recorder.write_connection_setup().await { error!(channel=%id, ?error, "Failed to record connection setup"); } self.traffic_connection_recorders.insert(id, recorder); } } } RCEvent::X11(id, originator_address, originator_port) => { if let Some(session) = &mut self.session_handle { let server_channel = session .channel_open_x11(originator_address, originator_port) .await?; self.channel_map .insert(ServerChannelId(server_channel.id()), id); self.all_channels.push(id); } } } Ok(()) } async fn handle_unknown_host_key( &mut self, key: PublicKey, reply: oneshot::Sender, ) -> Result<()> { self.service_output.hide_progress().await; let mode = self .services .config .lock() .await .store .ssh .host_key_verification; if mode == SshHostKeyVerificationMode::AutoAccept { let _ = reply.send(true); info!("Accepted untrusted host key (auto-accept is enabled)"); return Ok(()); } if mode == SshHostKeyVerificationMode::AutoReject { let _ = reply.send(false); info!("Rejected untrusted host key (auto-reject is enabled)"); return Ok(()); } if self.pty_channels.is_empty() { warn!("Target host key is not trusted, but there is no active PTY channel to show the trust prompt on."); warn!( "Connect to this target with an interactive session once to accept the host key." ); self.request_disconnect().await; anyhow::bail!("No PTY channel to show an interactive prompt on") } self.emit_service_message(&format!( "There is no trusted {} key for this host.", key.name() )) .await?; self.emit_service_message("Trust this key? 
(y/n)").await?; let mut sub = self .hub .subscribe(|e| matches!(e, Event::ConsoleInput(_))) .await; let mut service_output = self.service_output.clone(); tokio::spawn(async move { loop { match sub.recv().await { Some(Event::ConsoleInput(data)) => { if data == "y".as_bytes() { let _ = reply.send(true); break; } else if data == "n".as_bytes() { let _ = reply.send(false); break; } } None => break, _ => (), } } service_output.show_progress(); }); Ok(()) } async fn maybe_with_session<'a, FN, FT, R>(&'a mut self, f: FN) -> Result> where FN: FnOnce(&'a mut russh::server::Handle) -> FT + 'a, FT: futures::Future>, { if let Some(handle) = &mut self.session_handle { return Ok(Some(f(handle).await?)); } Ok(None) } async fn _channel_open_direct_tcpip( &mut self, channel: ServerChannelId, params: DirectTCPIPParams, ) -> Result { let uuid = Uuid::new_v4(); self.channel_map.insert(channel, uuid); info!(%channel, "Opening direct TCP/IP channel from {}:{} to {}:{}", params.originator_address, params.originator_port, params.host_to_connect, params.port_to_connect); let _ = self.maybe_connect_remote().await; match self .send_command_and_wait(RCCommand::Channel( uuid, ChannelOperation::OpenDirectTCPIP(params.clone()), )) .await { Ok(()) => { self.all_channels.push(uuid); let recorder = self .traffic_recorder_for( ¶ms.host_to_connect, params.port_to_connect, "direct-tcpip", ) .await; if let Some(recorder) = recorder { #[allow(clippy::unwrap_used)] let mut recorder = recorder.connection(TrafficConnectionParams { dst_addr: Ipv4Addr::from_str("2.2.2.2").unwrap(), dst_port: params.port_to_connect as u16, src_addr: Ipv4Addr::from_str("1.1.1.1").unwrap(), src_port: params.originator_port as u16, }); if let Err(error) = recorder.write_connection_setup().await { error!(%channel, ?error, "Failed to record connection setup"); } self.traffic_connection_recorders.insert(uuid, recorder); } Ok(true) } Err(SshClientError::ChannelFailure) => Ok(false), Err(x) => Err(x.into()), } } async fn 
_window_change_request( &mut self, server_channel_id: ServerChannelId, request: PtyRequest, ) -> Result<()> { let channel_id = self.map_channel(&server_channel_id)?; self.channel_pty_size_map .insert(channel_id, request.clone()); if let Some(recorder) = self.channel_recorders.get_mut(&channel_id) { if let Err(error) = recorder .write_pty_resize(request.col_width, request.row_height) .await { error!(%channel_id, ?error, "Failed to record terminal data"); self.channel_recorders.remove(&channel_id); } } self.send_command_and_wait(RCCommand::Channel( channel_id, ChannelOperation::ResizePty(request), )) .await?; Ok(()) } async fn _channel_exec_request( &mut self, server_channel_id: ServerChannelId, data: Bytes, ) -> Result<()> { let channel_id = self.map_channel(&server_channel_id)?; match std::str::from_utf8(&data) { Err(e) => { error!(channel=%channel_id, ?data, "Requested exec - invalid UTF-8"); anyhow::bail!(e) } Ok::<&str, _>(command) => { debug!(channel=%channel_id, %command, "Requested exec"); let _ = self.maybe_connect_remote().await; let _ = self.send_command(RCCommand::Channel( channel_id, ChannelOperation::RequestExec(command.to_string()), )); } } self.start_terminal_recording(channel_id, format!("exec-channel-{}", server_channel_id.0)) .await; Ok(()) } async fn start_terminal_recording(&mut self, channel_id: Uuid, name: String) { match async { let mut recorder = self .services .recordings .lock() .await .start::(&self.id, name) .await?; if let Some(request) = self.channel_pty_size_map.get(&channel_id) { recorder .write_pty_resize(request.col_width, request.row_height) .await?; } Ok::<_, recordings::Error>(recorder) } .await { Ok(recorder) => { self.channel_recorders.insert(channel_id, recorder); } Err(error) => match error { recordings::Error::Disabled => (), error => error!(channel=%channel_id, ?error, "Failed to start recording"), }, } } async fn _channel_x11_request( &mut self, server_channel_id: ServerChannelId, request: X11Request, ) -> Result<()> { let 
channel_id = self.map_channel(&server_channel_id)?; debug!(channel=%channel_id, "Requested X11"); let _ = self.maybe_connect_remote().await; self.send_command_and_wait(RCCommand::Channel( channel_id, ChannelOperation::RequestX11(request), )) .await?; Ok(()) } async fn _channel_env_request( &mut self, server_channel_id: ServerChannelId, name: String, value: String, ) -> Result<()> { let channel_id = self.map_channel(&server_channel_id)?; debug!(channel=%channel_id, %name, %value, "Environment"); self.send_command_and_wait(RCCommand::Channel( channel_id, ChannelOperation::RequestEnv(name, value), )) .await?; Ok(()) } async fn traffic_recorder_for( &mut self, host: &str, port: u32, tag: &str, ) -> Option<&mut TrafficRecorder> { let host = host.to_owned(); if let Vacant(e) = self.traffic_recorders.entry((host.clone(), port)) { match self .services .recordings .lock() .await .start(&self.id, format!("{tag}-{host}-{port}")) .await { Ok(recorder) => { e.insert(recorder); } Err(error) => { error!(%host, %port, ?error, "Failed to start recording"); } } } self.traffic_recorders.get_mut(&(host, port)) } pub async fn _channel_shell_request_nowait( &mut self, server_channel_id: ServerChannelId, ) -> Result<()> { let channel_id = self.map_channel(&server_channel_id)?; let _ = self.maybe_connect_remote().await; let _ = self.send_command(RCCommand::Channel( channel_id, ChannelOperation::RequestShell, )); self.start_terminal_recording(channel_id, format!("shell-channel-{}", server_channel_id.0)) .await; info!(%channel_id, "Opening shell"); let _ = self .session_handle .as_mut() .context("Invalid session state")? 
.channel_success(server_channel_id.0) .await; Ok(()) } pub async fn _channel_shell_request_begin( &mut self, server_channel_id: ServerChannelId, ) -> Result<()> { let channel_id = self.map_channel(&server_channel_id)?; let _ = self.maybe_connect_remote().await; self.send_command_and_wait(RCCommand::Channel( channel_id, ChannelOperation::RequestShell, )) .await .map_err(anyhow::Error::from) } pub async fn _channel_shell_request_finish( &mut self, server_channel_id: ServerChannelId, ) -> Result<()> { let channel_id = self.map_channel(&server_channel_id)?; self.start_terminal_recording(channel_id, format!("shell-channel-{}", server_channel_id.0)) .await; info!(%channel_id, "Opening shell"); let session = self .session_handle .clone() .context("Invalid session state")?; tokio::spawn(async move { session.channel_success(server_channel_id.0).await }); Ok(()) } pub async fn _channel_subsystem_request( &mut self, server_channel_id: ServerChannelId, name: String, ) -> Result<()> { let channel_id = self.map_channel(&server_channel_id)?; info!(channel=%channel_id, "Requesting subsystem {}", &name); let _ = self.maybe_connect_remote().await; self.send_command_and_wait(RCCommand::Channel( channel_id, ChannelOperation::RequestSubsystem(name), )) .await?; Ok(()) } async fn _data(&mut self, server_channel_id: ServerChannelId, data: Bytes) -> Result<()> { let channel_id = self.map_channel(&server_channel_id)?; debug!(channel=%server_channel_id.0, ?data, "Data"); if self.rc_state == RCState::Connecting && data.first() == Some(&3) { info!(channel=%channel_id, "User requested connection abort (Ctrl-C)"); self.request_disconnect().await; return Ok(()); } if let Some(recorder) = self.channel_recorders.get_mut(&channel_id) { if let Err(error) = recorder .write(TerminalRecordingStreamId::Input, &data) .await { error!(channel=%channel_id, ?error, "Failed to record terminal data"); self.channel_recorders.remove(&channel_id); } } if let Some(recorder) = 
self.traffic_connection_recorders.get_mut(&channel_id) { if let Err(error) = recorder.write_tx(&data).await { error!(channel=%channel_id, ?error, "Failed to record traffic data"); self.traffic_connection_recorders.remove(&channel_id); } } if self.pty_channels.contains(&channel_id) { let _ = self .event_sender .send_once(Event::ConsoleInput(data.clone())) .await; } let _ = self.send_command(RCCommand::Channel(channel_id, ChannelOperation::Data(data))); Ok(()) } async fn _extended_data( &mut self, server_channel_id: ServerChannelId, code: u32, data: Bytes, ) -> Result<()> { let channel_id = self.map_channel(&server_channel_id)?; debug!(channel=%server_channel_id.0, ?data, "Data"); let _ = self.send_command_and_wait(RCCommand::Channel( channel_id, ChannelOperation::ExtendedData { ext: code, data }, )); Ok(()) } async fn _tcpip_forward(&mut self, address: String, port: u32) -> Result<()> { info!(%address, %port, "Remote port forwarding requested"); let _ = self.maybe_connect_remote().await; self.send_command_and_wait(RCCommand::ForwardTCPIP(address, port)) .await .map_err(anyhow::Error::from) } pub async fn _cancel_tcpip_forward(&mut self, address: String, port: u32) -> Result<()> { info!(%address, %port, "Remote port forwarding cancelled"); self.send_command_and_wait(RCCommand::CancelTCPIPForward(address, port)) .await .map_err(anyhow::Error::from) } async fn _auth_publickey( &mut self, ssh_username: Secret, key: PublicKey, ) -> russh::server::Auth { let selector: AuthSelector = ssh_username.expose_secret().into(); info!( "Public key auth as {:?} with key {}", selector, key.public_key_base64() ); match self .try_auth( &selector, Some(AuthCredential::PublicKey { kind: key.name().to_string(), public_key_bytes: Bytes::from(key.public_key_bytes()), }), ) .await { Ok(AuthResult::Accepted { .. 
}) => russh::server::Auth::Accept, Ok(AuthResult::Rejected) => russh::server::Auth::Reject { proceed_with_methods: Some(MethodSet::all()), }, Ok(AuthResult::Need(kinds)) => russh::server::Auth::Reject { proceed_with_methods: Some(self.get_remaining_auth_methods(kinds)), }, Err(error) => { error!(?error, "Failed to verify credentials"); russh::server::Auth::Reject { proceed_with_methods: None, } } } }

    /// Handles an SSH `password` authentication attempt.
    ///
    /// The SSH username is parsed into an [`AuthSelector`] and the password is
    /// checked via [`Self::try_auth`]. Only a fully accepted result lets the
    /// client in; a rejection, a still-outstanding credential requirement, or an
    /// internal error all produce `Reject` without suggesting follow-up methods.
    async fn _auth_password(
        &mut self,
        ssh_username: Secret<String>,
        password: Secret<String>,
    ) -> russh::server::Auth {
        let selector: AuthSelector = ssh_username.expose_secret().into();
        info!("Password auth as {:?}", selector);
        match self
            .try_auth(&selector, Some(AuthCredential::Password(password)))
            .await
        {
            Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
            // NOTE(review): unlike the public-key handler, `Need` here does not
            // advertise the remaining methods via `get_remaining_auth_methods` —
            // confirm this is intended.
            Ok(AuthResult::Rejected | AuthResult::Need(_)) => russh::server::Auth::Reject {
                proceed_with_methods: None,
            },
            Err(error) => {
                error!(?error, "Failed to verify credentials");
                russh::server::Auth::Reject {
                    proceed_with_methods: None,
                }
            }
        }
    }

    /// Handles one round of SSH `keyboard-interactive` authentication.
    ///
    /// `self.keyboard_interactive_state` records what the previous round asked for:
    /// - `None`: nothing was prompted, so no credential is submitted.
    /// - `OtpRequested`: the client's response is submitted as a one-time password.
    /// - `WebAuthRequested`: waits on the stored auth-state subscription, then
    ///   re-checks the auth state without a new credential.
    ///
    /// The state is reset before calling [`Self::try_auth`]. If more credentials
    /// are still needed, a new prompt is returned as `Auth::Partial` and the state
    /// is set accordingly for the next round.
    async fn _auth_keyboard_interactive(
        &mut self,
        ssh_username: Secret<String>,
        response: Option<Secret<String>>,
    ) -> russh::server::Auth {
        let selector: AuthSelector = ssh_username.expose_secret().into();
        info!("Keyboard-interactive auth as {:?}", selector);

        let cred = match &mut self.keyboard_interactive_state {
            KeyboardInteractiveState::None => None,
            KeyboardInteractiveState::OtpRequested => response.map(AuthCredential::Otp),
            KeyboardInteractiveState::WebAuthRequested(event) => {
                // Only used as a wake-up signal; the auth state has been updated by
                // now and is re-read by try_auth below.
                let _ = event.recv().await;
                None
            }
        };
        self.keyboard_interactive_state = KeyboardInteractiveState::None;

        match self.try_auth(&selector, cred).await {
            Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
            Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
                proceed_with_methods: None,
            },
            Ok(AuthResult::Need(kinds)) => {
                if kinds.contains(&CredentialKind::Totp) {
                    // Ask for the OTP; the answer arrives as `response` next round.
                    self.keyboard_interactive_state = KeyboardInteractiveState::OtpRequested;
                    russh::server::Auth::Partial {
                        name: Cow::Borrowed("Two-factor authentication"),
                        instructions: Cow::Borrowed(""),
                        prompts: Cow::Owned(vec![(Cow::Borrowed("One-time password: "), true)]),
                    }
                } else if kinds.contains(&CredentialKind::WebUserApproval) {
                    let Some(auth_state) = self.auth_state.as_ref() else {
                        return russh::server::Auth::Reject {
                            proceed_with_methods: None,
                        };
                    };
                    let auth_state_id = *auth_state.lock().await.id();
                    // Keep the subscription so the next round can wait on it.
                    let event = self
                        .services
                        .auth_state_store
                        .lock()
                        .await
                        .subscribe(auth_state_id);
                    self.keyboard_interactive_state =
                        KeyboardInteractiveState::WebAuthRequested(event);

                    let mut login_url = match self
                        .services
                        .config
                        .lock()
                        .await
                        .construct_external_url(None)
                    {
                        Ok(url) => url,
                        Err(error) => {
                            error!(?error, "Failed to construct external URL");
                            return russh::server::Auth::Reject {
                                proceed_with_methods: None,
                            };
                        }
                    };

                    login_url.set_path("@warpgate");
                    login_url.set_fragment(Some(&format!("/login/{auth_state_id}")));

                    russh::server::Auth::Partial {
                        name: Cow::Owned(format!(
                            concat!(
                                "----------------------------------------------------------------\n",
                                "Warpgate authentication: please open {} in your browser\n",
                                "----------------------------------------------------------------\n"
                            ),
                            login_url
                        )),
                        instructions: Cow::Borrowed(""),
                        prompts: Cow::Owned(vec![(Cow::Borrowed("Press Enter when done: "), true)]),
                    }
                } else {
                    russh::server::Auth::Reject {
                        proceed_with_methods: None,
                    }
                }
            }
            Err(error) => {
                error!(?error, "Failed to verify credentials");
                russh::server::Auth::Reject {
                    proceed_with_methods: None,
                }
            }
        }
    }

    /// Maps the credential kinds that are still required to the SSH auth methods
    /// the client should continue with. OTP, web approval and SSO are all mapped
    /// to `keyboard-interactive`.
    fn get_remaining_auth_methods(&self, kinds: HashSet<CredentialKind>) -> MethodSet {
        let mut m = MethodSet::empty();
        for kind in kinds {
            m.insert(match kind {
                CredentialKind::Password => MethodSet::PASSWORD,
                CredentialKind::PublicKey => MethodSet::PUBLICKEY,
                CredentialKind::Totp | CredentialKind::WebUserApproval | CredentialKind::Sso => {
                    MethodSet::KEYBOARD_INTERACTIVE
                }
            });
        }
        m
    }

    /// Submits an optional credential for `selector` and reports the outcome.
    ///
    /// - `AuthSelector::User`: a credential that the config provider validates is
    ///   added to the user's auth state. Once the state verifies as accepted, it is
    ///   completed in the auth state store, the user's access to the target is
    ///   checked, and the session is bound to user and target via `_auth_accept`.
    /// - `AuthSelector::Ticket`: a ticket that authorizes is consumed and the
    ///   session is accepted for the ticket's user and target.
    ///
    /// # Errors
    /// Propagates failures from the auth state lookup, the config provider, the
    /// ticket database calls and `_auth_accept`.
    async fn try_auth(
        &mut self,
        selector: &AuthSelector,
        credential: Option<AuthCredential>,
    ) -> Result<AuthResult> {
        match selector {
            AuthSelector::User {
                username,
                target_name,
            } => {
                let cp = self.services.config_provider.clone();

                let state_arc = self.get_auth_state(username).await?;
                let mut state = state_arc.lock().await;

                if let Some(credential) = credential {
                    if cp
                        .lock()
                        .await
                        .validate_credential(username, &credential)
                        .await?
                    {
                        state.add_valid_credential(credential);
                    }
                }

                match state.verify() {
                    AuthResult::Accepted { username } => {
                        // NOTE(review): `complete` runs before the target check
                        // below, so it also runs when the target is then rejected —
                        // confirm this is intended.
                        self.services
                            .auth_state_store
                            .lock()
                            .await
                            .complete(state.id())
                            .await;
                        let target_auth_result = self
                            .services
                            .config_provider
                            .lock()
                            .await
                            .authorize_target(&username, target_name)
                            .await?;
                        if !target_auth_result {
                            warn!(
                                "Target {} not authorized for user {}",
                                target_name, username
                            );
                            return Ok(AuthResult::Rejected);
                        }
                        self._auth_accept(&username, target_name).await?;
                        Ok(AuthResult::Accepted { username })
                    }
                    x => Ok(x),
                }
            }
            AuthSelector::Ticket { secret } => {
                match authorize_ticket(&self.services.db, secret).await? {
                    Some(ticket) => {
                        info!("Authorized for {} with a ticket", ticket.target);
                        consume_ticket(&self.services.db, &ticket.id).await?;
                        self._auth_accept(&ticket.username, &ticket.target).await?;

                        Ok(AuthResult::Accepted {
                            username: ticket.username.clone(),
                        })
                    }
                    None => Ok(AuthResult::Rejected),
                }
            }
        }
    }

    /// Binds the session to an authenticated user and the requested target.
    ///
    /// Records `username` on the server handle and on the session, then looks up
    /// an SSH target named `target_name`. If none exists, the selection is stored
    /// as `TargetSelection::NotFound` and `Ok(())` is still returned. If the found
    /// target has no username configured, the authenticated user's name is used.
    ///
    /// # Errors
    /// Fails if the list of targets cannot be retrieved.
    async fn _auth_accept(
        &mut self,
        username: &str,
        target_name: &str,
    ) -> Result<(), WarpgateError> {
        info!(%username, "Authenticated");

        // The result of updating the server handle is deliberately ignored.
        let _ = self
            .server_handle
            .lock()
            .await
            .set_username(username.to_string())
            .await;
        self.username = Some(username.to_string());

        let target = {
            self.services
                .config_provider
                .lock()
                .await
                .list_targets()
                .await?
                .iter()
                .filter_map(|t| match t.options {
                    TargetOptions::Ssh(ref options) => Some((t, options)),
                    _ => None,
                })
                .find(|(t, _)| t.name == target_name)
                .map(|(t, opt)| (t.clone(), opt.clone()))
        };

        let Some((target, mut ssh_options)) = target else {
            self.target = TargetSelection::NotFound(target_name.to_string());
            warn!("Selected target not found");
            return Ok(());
        };

        // Forward username from the authenticated user to the target, if target has no username
        if ssh_options.username.is_empty() {
            ssh_options.username = username.to_string();
        }

        let _ = self.server_handle.lock().await.set_target(&target).await;
        self.target = TargetSelection::Found(target, ssh_options);

        Ok(())
    }

    /// Forwards a client-side channel close to the remote client and waits for
    /// it to be processed.
    async fn _channel_close(&mut self, server_channel_id: ServerChannelId) -> Result<()> {
        let channel_id = self.map_channel(&server_channel_id)?;
        debug!(channel=%channel_id, "Closing channel");
        self.send_command_and_wait(RCCommand::Channel(channel_id, ChannelOperation::Close))
            .await?;
        Ok(())
    }

    /// Forwards a client-side EOF to the remote client without waiting; a failed
    /// send is ignored.
    async fn _channel_eof(&mut self, server_channel_id: ServerChannelId) -> Result<()> {
        let channel_id = self.map_channel(&server_channel_id)?;
        debug!(channel=%channel_id, "EOF");
        let _ = self.send_command(RCCommand::Channel(channel_id, ChannelOperation::Eof));
        Ok(())
    }

    /// Forwards a signal received on a client channel to the remote client and
    /// waits for it to be processed.
    pub async fn _channel_signal(
        &mut self,
        server_channel_id: ServerChannelId,
        signal: Sig,
    ) -> Result<()> {
        let channel_id = self.map_channel(&server_channel_id)?;
        debug!(channel=%channel_id, ?signal, "Signal");
        self.send_command_and_wait(RCCommand::Channel(
            channel_id,
            ChannelOperation::Signal(signal),
        ))
        .await?;
        Ok(())
    }

    /// Queues `command` for the remote client without requesting a reply.
    ///
    /// # Errors
    /// Hands the command back if the command channel's receiver has been dropped.
    fn send_command(&mut self, command: RCCommand) -> Result<(), RCCommand> {
        self.rc_tx.send((command, None)).map_err(|e| e.0 .0)
    }

    /// Sends `command` to the remote client and waits for its reply.
    ///
    /// While waiting, further session events are still taken from
    /// `get_next_event` and handled, so the session keeps running until the reply
    /// arrives.
    ///
    /// # Errors
    /// `SshClientError::MpscError` if the command could not be sent, its reply
    /// sender was dropped, or no further session events are available. Otherwise
    /// returns the command's own error, or an error from handling an interleaved
    /// event.
    async fn send_command_and_wait(&mut self, command: RCCommand) -> Result<(), SshClientError> {
        let (tx, rx) = oneshot::channel();
        let mut cmd = match self.rc_tx.send((command, Some(tx))) {
            Ok(_) => PendingCommand::Waiting(rx),
            Err(_) => PendingCommand::Failed,
        };

        loop {
            tokio::select! {
                result = &mut cmd => {
                    return result
                }
                event = self.get_next_event() => {
                    let Some(event) = event else {
                        return Err(SshClientError::MpscError);
                    };
                    self.handle_event(event).await.map_err(SshClientError::from)?;
                }
            }
        }
    }

    /// Entry point for a disconnect requested by the SSH client.
    pub async fn _disconnect(&mut self) {
        debug!("Client disconnect requested");
        self.request_disconnect().await;
    }

    /// Aborts the remote client and, if it has been started and is not already
    /// disconnected, also sends it an explicit `Disconnect` command. Both sends
    /// are best effort and their results are ignored.
    async fn request_disconnect(&mut self) {
        debug!("Disconnecting");
        let _ = self.rc_abort_tx.send(());
        if self.rc_state != RCState::NotInitialized && self.rc_state != RCState::Disconnected {
            let _ = self.send_command(RCCommand::Disconnect);
        }
    }

    /// Closes every known channel on the client-facing SSH session and then drops
    /// the session handle. Channels that cannot be mapped back to a server channel
    /// ID are skipped, and close errors are ignored.
    async fn disconnect_server(&mut self) {
        let all_channels = std::mem::take(&mut self.all_channels);
        let channels = all_channels
            .into_iter()
            .map(|x| self.map_channel_reverse(&x))
            .filter_map(|x| x.ok())
            .collect::<Vec<_>>();

        let _ = self
            .maybe_with_session(|handle| async move {
                for ch in channels {
                    let _ = handle.close(ch.0).await;
                }
                Ok(())
            })
            .await;
        self.session_handle = None;
    }
}

impl Drop for ServerSession {
    /// Signals the remote client to abort (the send result is ignored) and logs
    /// that the session has closed.
    fn drop(&mut self) {
        let _ = self.rc_abort_tx.send(());
        info!("Closed session");
        debug!("Dropped");
    }
}

/// A command handed to the remote client together with a reply channel.
/// Awaiting it yields the command's result.
pub enum PendingCommand {
    /// The command was queued; its reply arrives on this receiver.
    Waiting(oneshot::Receiver<Result<(), SshClientError>>),
    /// Sending the command on the command channel failed.
    Failed,
}

impl Future for PendingCommand {
    type Output = Result<(), SshClientError>;

    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        match self.get_mut() {
            PendingCommand::Waiting(ref mut rx) => match Pin::new(rx).poll(cx) {
                // A dropped reply sender is reported like a failed send.
                Poll::Ready(result) => {
                    Poll::Ready(result.unwrap_or(Err(SshClientError::MpscError)))
                }
                Poll::Pending => Poll::Pending,
            },
            PendingCommand::Failed => Poll::Ready(Err(SshClientError::MpscError)),
        }
    }
}