mirror of
https://github.com/warp-tech/warpgate.git
synced 2025-10-06 05:17:42 +08:00
1610 lines
58 KiB
Rust
1610 lines
58 KiB
Rust
use std::borrow::Cow;
|
|
use std::collections::hash_map::Entry::Vacant;
|
|
use std::collections::{HashMap, HashSet};
|
|
use std::net::{Ipv4Addr, SocketAddr};
|
|
use std::pin::Pin;
|
|
use std::str::FromStr;
|
|
use std::sync::Arc;
|
|
use std::task::Poll;
|
|
|
|
use ansi_term::Colour;
|
|
use anyhow::{Context, Result};
|
|
use bimap::BiMap;
|
|
use bytes::Bytes;
|
|
use futures::{Future, FutureExt};
|
|
use russh::{CryptoVec, MethodSet, Sig};
|
|
use russh_keys::key::PublicKey;
|
|
use russh_keys::PublicKeyBase64;
|
|
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
|
|
use tokio::sync::{broadcast, oneshot, Mutex};
|
|
use tracing::*;
|
|
use uuid::Uuid;
|
|
use warpgate_common::auth::{AuthCredential, AuthResult, AuthSelector, AuthState, CredentialKind};
|
|
use warpgate_common::eventhub::{EventHub, EventSender, EventSubscription};
|
|
use warpgate_common::{
|
|
Secret, SessionId, SshHostKeyVerificationMode, Target, TargetOptions, TargetSSHOptions,
|
|
WarpgateError,
|
|
};
|
|
use warpgate_core::recordings::{
|
|
self, ConnectionRecorder, TerminalRecorder, TerminalRecordingStreamId, TrafficConnectionParams,
|
|
TrafficRecorder,
|
|
};
|
|
use warpgate_core::{authorize_ticket, consume_ticket, Services, WarpgateServerHandle};
|
|
|
|
use super::channel_writer::ChannelWriter;
|
|
use super::russh_handler::ServerHandlerEvent;
|
|
use super::service_output::ServiceOutput;
|
|
use super::session_handle::SessionHandleCommand;
|
|
use crate::compat::ContextExt;
|
|
use crate::server::service_output::ERASE_PROGRESS_SPINNER;
|
|
use crate::{
|
|
ChannelOperation, ConnectionError, DirectTCPIPParams, PtyRequest, RCCommand, RCCommandReply,
|
|
RCEvent, RCState, RemoteClient, ServerChannelId, SshClientError, X11Request,
|
|
};
|
|
|
|
#[derive(Clone)]
|
|
#[allow(clippy::large_enum_variant)]
|
|
enum TargetSelection {
|
|
None,
|
|
NotFound(String),
|
|
Found(Target, TargetSSHOptions),
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
enum Event {
|
|
Command(SessionHandleCommand),
|
|
ServerHandler(ServerHandlerEvent),
|
|
ConsoleInput(Bytes),
|
|
ServiceOutput(Bytes),
|
|
Client(RCEvent),
|
|
}
|
|
|
|
enum KeyboardInteractiveState {
|
|
None,
|
|
OtpRequested,
|
|
WebAuthRequested(broadcast::Receiver<AuthResult>),
|
|
}
|
|
|
|
pub struct ServerSession {
|
|
pub id: SessionId,
|
|
username: Option<String>,
|
|
session_handle: Option<russh::server::Handle>,
|
|
pty_channels: Vec<Uuid>,
|
|
all_channels: Vec<Uuid>,
|
|
channel_recorders: HashMap<Uuid, TerminalRecorder>,
|
|
channel_map: BiMap<ServerChannelId, Uuid>,
|
|
channel_pty_size_map: HashMap<Uuid, PtyRequest>,
|
|
rc_tx: UnboundedSender<(RCCommand, Option<RCCommandReply>)>,
|
|
rc_abort_tx: UnboundedSender<()>,
|
|
rc_state: RCState,
|
|
remote_address: SocketAddr,
|
|
services: Services,
|
|
server_handle: Arc<Mutex<WarpgateServerHandle>>,
|
|
target: TargetSelection,
|
|
traffic_recorders: HashMap<(String, u32), TrafficRecorder>,
|
|
traffic_connection_recorders: HashMap<Uuid, ConnectionRecorder>,
|
|
hub: EventHub<Event>,
|
|
event_sender: EventSender<Event>,
|
|
main_event_subscription: EventSubscription<Event>,
|
|
service_output: ServiceOutput,
|
|
channel_writer: ChannelWriter,
|
|
auth_state: Option<Arc<Mutex<AuthState>>>,
|
|
keyboard_interactive_state: KeyboardInteractiveState,
|
|
}
|
|
|
|
fn session_debug_tag(id: &SessionId, remote_address: &SocketAddr) -> String {
|
|
format!("[{} - {}]", id, remote_address)
|
|
}
|
|
|
|
impl std::fmt::Debug for ServerSession {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "{}", session_debug_tag(&self.id, &self.remote_address))
|
|
}
|
|
}
|
|
|
|
impl ServerSession {
|
|
pub async fn start(
|
|
remote_address: SocketAddr,
|
|
services: &Services,
|
|
server_handle: Arc<Mutex<WarpgateServerHandle>>,
|
|
mut session_handle_rx: UnboundedReceiver<SessionHandleCommand>,
|
|
mut handler_event_rx: UnboundedReceiver<ServerHandlerEvent>,
|
|
) -> Result<impl Future<Output = Result<()>>> {
|
|
let id = server_handle.lock().await.id();
|
|
|
|
let _span = info_span!("SSH", session=%id);
|
|
let _enter = _span.enter();
|
|
|
|
let mut rc_handles = RemoteClient::create(id, services.clone());
|
|
|
|
let (hub, event_sender) = EventHub::setup();
|
|
let main_event_subscription = hub
|
|
.subscribe(|e| !matches!(e, Event::ConsoleInput(_)))
|
|
.await;
|
|
|
|
let mut this = Self {
|
|
id,
|
|
username: None,
|
|
session_handle: None,
|
|
pty_channels: vec![],
|
|
all_channels: vec![],
|
|
channel_recorders: HashMap::new(),
|
|
channel_map: BiMap::new(),
|
|
channel_pty_size_map: HashMap::new(),
|
|
rc_tx: rc_handles.command_tx.clone(),
|
|
rc_abort_tx: rc_handles.abort_tx,
|
|
rc_state: RCState::NotInitialized,
|
|
remote_address,
|
|
services: services.clone(),
|
|
server_handle,
|
|
target: TargetSelection::None,
|
|
traffic_recorders: HashMap::new(),
|
|
traffic_connection_recorders: HashMap::new(),
|
|
hub,
|
|
event_sender: event_sender.clone(),
|
|
main_event_subscription,
|
|
service_output: ServiceOutput::new(),
|
|
channel_writer: ChannelWriter::new(),
|
|
auth_state: None,
|
|
keyboard_interactive_state: KeyboardInteractiveState::None,
|
|
};
|
|
|
|
let mut so_rx = this.service_output.subscribe();
|
|
let so_sender = event_sender.clone();
|
|
tokio::spawn(async move {
|
|
loop {
|
|
match so_rx.recv().await {
|
|
Ok(data) => {
|
|
if so_sender
|
|
.send_once(Event::ServiceOutput(data))
|
|
.await
|
|
.is_err()
|
|
{
|
|
break;
|
|
}
|
|
}
|
|
Err(broadcast::error::RecvError::Closed) => break,
|
|
Err(_) => (),
|
|
}
|
|
}
|
|
});
|
|
|
|
let name = format!("SSH {} session control", id);
|
|
tokio::task::Builder::new().name(&name).spawn({
|
|
let sender = event_sender.clone();
|
|
async move {
|
|
while let Some(command) = session_handle_rx.recv().await {
|
|
if sender.send_once(Event::Command(command)).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
let name = format!("SSH {} client events", id);
|
|
tokio::task::Builder::new().name(&name).spawn({
|
|
let sender = event_sender.clone();
|
|
async move {
|
|
while let Some(e) = rc_handles.event_rx.recv().await {
|
|
if sender.send_once(Event::Client(e)).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
let name = format!("SSH {} server handler events", id);
|
|
tokio::task::Builder::new().name(&name).spawn({
|
|
let sender = event_sender.clone();
|
|
async move {
|
|
while let Some(e) = handler_event_rx.recv().await {
|
|
if sender.send_once(Event::ServerHandler(e)).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
Ok(async move {
|
|
while let Some(event) = this.get_next_event().await {
|
|
this.handle_event(event).await?;
|
|
}
|
|
debug!("No more events");
|
|
Ok::<_, anyhow::Error>(())
|
|
})
|
|
}
|
|
|
|
/// Waits for the next event on the main subscription, which receives every
/// event kind except console input. `None` means no more events will arrive.
async fn get_next_event(&mut self) -> Option<Event> {
    self.main_event_subscription.recv().await
}
|
|
|
|
async fn get_auth_state(&mut self, username: &str) -> Result<Arc<Mutex<AuthState>>> {
|
|
#[allow(clippy::unwrap_used)]
|
|
if self.auth_state.is_none()
|
|
|| self.auth_state.as_ref().unwrap().lock().await.username() != username
|
|
{
|
|
let state = self
|
|
.services
|
|
.auth_state_store
|
|
.lock()
|
|
.await
|
|
.create(username, crate::PROTOCOL_NAME)
|
|
.await?
|
|
.1;
|
|
self.auth_state = Some(state);
|
|
}
|
|
#[allow(clippy::unwrap_used)]
|
|
Ok(self.auth_state.as_ref().map(Clone::clone).unwrap())
|
|
}
|
|
|
|
pub fn make_logging_span(&self) -> tracing::Span {
|
|
match self.username {
|
|
Some(ref username) => info_span!("SSH", session=%self.id, session_username=%username),
|
|
None => info_span!("SSH", session=%self.id),
|
|
}
|
|
}
|
|
|
|
fn map_channel(&self, ch: &ServerChannelId) -> Result<Uuid> {
|
|
self.channel_map
|
|
.get_by_left(ch)
|
|
.cloned()
|
|
.ok_or_else(|| anyhow::anyhow!("Channel not known"))
|
|
}
|
|
|
|
fn map_channel_reverse(&self, ch: &Uuid) -> Result<ServerChannelId> {
|
|
self.channel_map
|
|
.get_by_right(ch)
|
|
.cloned()
|
|
.ok_or_else(|| anyhow::anyhow!("Channel not known"))
|
|
}
|
|
|
|
pub async fn emit_service_message(&mut self, msg: &str) -> Result<()> {
|
|
debug!("Service message: {}", msg);
|
|
|
|
self.service_output.emit_output(Bytes::from(format!(
|
|
"{}{} {}\r\n",
|
|
ERASE_PROGRESS_SPINNER,
|
|
Colour::Black.on(Colour::White).paint(" Warpgate "),
|
|
msg.replace('\n', "\r\n"),
|
|
)));
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn emit_pty_output(&mut self, data: &[u8]) -> Result<()> {
|
|
let channels = self.pty_channels.clone();
|
|
for channel in channels {
|
|
let channel = self.map_channel_reverse(&channel)?;
|
|
if let Some(session) = self.session_handle.clone() {
|
|
self.channel_writer
|
|
.write(session, channel.0, CryptoVec::from_slice(data));
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn maybe_connect_remote(&mut self) -> Result<()> {
|
|
match self.target.clone() {
|
|
TargetSelection::None => {
|
|
anyhow::bail!("Invalid session state (target not set)")
|
|
}
|
|
TargetSelection::NotFound(name) => {
|
|
self.emit_service_message(&format!("Selected target not found: {name}"))
|
|
.await?;
|
|
self.disconnect_server().await;
|
|
anyhow::bail!("Target not found: {}", name);
|
|
}
|
|
TargetSelection::Found(target, ssh_options) => {
|
|
if self.rc_state == RCState::NotInitialized {
|
|
self.connect_remote(target, ssh_options).await?;
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn connect_remote(
|
|
&mut self,
|
|
target: Target,
|
|
ssh_options: TargetSSHOptions,
|
|
) -> Result<()> {
|
|
self.rc_state = RCState::Connecting;
|
|
self.send_command(RCCommand::Connect(ssh_options))
|
|
.map_err(|_| anyhow::anyhow!("cannot send command"))?;
|
|
self.service_output.show_progress();
|
|
self.emit_service_message(&format!("Selected target: {}", target.name))
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn handle_event<'a>(
|
|
&'a mut self,
|
|
event: Event,
|
|
) -> Pin<Box<dyn Future<Output = Result<(), WarpgateError>> + Send + 'a>> {
|
|
async move {
|
|
match event {
|
|
Event::Client(RCEvent::Done) => Err(WarpgateError::SessionEnd)?,
|
|
Event::ServerHandler(ServerHandlerEvent::Disconnect) => {
|
|
Err(WarpgateError::SessionEnd)?
|
|
}
|
|
Event::Client(e) => {
|
|
debug!(event=?e, "Event");
|
|
let span = self.make_logging_span();
|
|
if let Err(err) = self.handle_remote_event(e).instrument(span).await {
|
|
error!("Event handler error: {:?}", err);
|
|
// break;
|
|
}
|
|
}
|
|
Event::ServerHandler(e) => {
|
|
let span = self.make_logging_span();
|
|
if let Err(err) = self.handle_server_handler_event(e).instrument(span).await {
|
|
error!("Event handler error: {:?}", err);
|
|
// break;
|
|
}
|
|
}
|
|
Event::Command(command) => {
|
|
debug!(?command, "Session control");
|
|
if let Err(err) = self.handle_session_control(command).await {
|
|
error!("Event handler error: {:?}", err);
|
|
// break;
|
|
}
|
|
}
|
|
Event::ServiceOutput(data) => {
|
|
let _ = self.emit_pty_output(&data).await;
|
|
}
|
|
Event::ConsoleInput(_) => (),
|
|
}
|
|
Ok(())
|
|
}
|
|
.boxed()
|
|
}
|
|
|
|
async fn handle_server_handler_event(&mut self, event: ServerHandlerEvent) -> Result<()> {
|
|
match event {
|
|
ServerHandlerEvent::Authenticated(handle) => {
|
|
self.session_handle = Some(handle.0);
|
|
}
|
|
|
|
ServerHandlerEvent::ChannelOpenSession(server_channel_id, reply) => {
|
|
let channel = Uuid::new_v4();
|
|
self.channel_map.insert(server_channel_id, channel);
|
|
|
|
info!(%channel, "Opening session channel");
|
|
return match self
|
|
.send_command_and_wait(RCCommand::Channel(channel, ChannelOperation::OpenShell))
|
|
.await
|
|
{
|
|
Ok(()) => {
|
|
self.all_channels.push(channel);
|
|
let _ = reply.send(true);
|
|
Ok(())
|
|
}
|
|
Err(SshClientError::ChannelFailure) => {
|
|
let _ = reply.send(false);
|
|
Ok(())
|
|
}
|
|
Err(x) => Err(x.into()),
|
|
};
|
|
}
|
|
|
|
ServerHandlerEvent::SubsystemRequest(server_channel_id, name, _) => {
|
|
self._channel_subsystem_request(server_channel_id, name)
|
|
.await?;
|
|
}
|
|
|
|
ServerHandlerEvent::PtyRequest(server_channel_id, request, _) => {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
self.channel_pty_size_map
|
|
.insert(channel_id, request.clone());
|
|
if let Some(recorder) = self.channel_recorders.get_mut(&channel_id) {
|
|
if let Err(error) = recorder
|
|
.write_pty_resize(request.col_width, request.row_height)
|
|
.await
|
|
{
|
|
error!(%channel_id, ?error, "Failed to record terminal data");
|
|
self.channel_recorders.remove(&channel_id);
|
|
}
|
|
}
|
|
self.send_command_and_wait(RCCommand::Channel(
|
|
channel_id,
|
|
ChannelOperation::RequestPty(request),
|
|
))
|
|
.await?;
|
|
let _ = self
|
|
.session_handle
|
|
.as_mut()
|
|
.context("Invalid session state")?
|
|
.channel_success(server_channel_id.0)
|
|
.await;
|
|
self.pty_channels.push(channel_id);
|
|
}
|
|
|
|
ServerHandlerEvent::ShellRequest(server_channel_id, _) => {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
let _ = self.maybe_connect_remote().await;
|
|
|
|
let _ = self.send_command(RCCommand::Channel(
|
|
channel_id,
|
|
ChannelOperation::RequestShell,
|
|
));
|
|
|
|
self.start_terminal_recording(
|
|
channel_id,
|
|
format!("shell-channel-{}", server_channel_id.0),
|
|
)
|
|
.await;
|
|
|
|
info!(%channel_id, "Opening shell");
|
|
|
|
let _ = self
|
|
.session_handle
|
|
.as_mut()
|
|
.context("Invalid session state")?
|
|
.channel_success(server_channel_id.0)
|
|
.await;
|
|
}
|
|
|
|
ServerHandlerEvent::AuthPublicKey(username, key, reply) => {
|
|
let _ = reply.send(self._auth_publickey(username, key).await);
|
|
}
|
|
|
|
ServerHandlerEvent::AuthPassword(username, password, reply) => {
|
|
let _ = reply.send(self._auth_password(username, password).await);
|
|
}
|
|
|
|
ServerHandlerEvent::AuthKeyboardInteractive(username, response, reply) => {
|
|
let _ = reply.send(self._auth_keyboard_interactive(username, response).await);
|
|
}
|
|
|
|
ServerHandlerEvent::Data(channel, data, _) => {
|
|
self._data(channel, data).await?;
|
|
}
|
|
|
|
ServerHandlerEvent::ExtendedData(channel, data, code, _) => {
|
|
self._extended_data(channel, code, data).await?;
|
|
}
|
|
|
|
ServerHandlerEvent::ChannelClose(channel, _) => {
|
|
self._channel_close(channel).await?;
|
|
}
|
|
|
|
ServerHandlerEvent::ChannelEof(channel, _) => {
|
|
self._channel_eof(channel).await?;
|
|
}
|
|
|
|
ServerHandlerEvent::WindowChangeRequest(channel, request, _) => {
|
|
self._window_change_request(channel, request).await?;
|
|
}
|
|
|
|
ServerHandlerEvent::Signal(channel, signal, _) => {
|
|
self._channel_signal(channel, signal).await?;
|
|
}
|
|
|
|
ServerHandlerEvent::ExecRequest(channel, data, _) => {
|
|
self._channel_exec_request(channel, data).await?;
|
|
}
|
|
|
|
ServerHandlerEvent::ChannelOpenDirectTcpIp(channel, params, reply) => {
|
|
let _ = reply.send(self._channel_open_direct_tcpip(channel, params).await?);
|
|
}
|
|
|
|
ServerHandlerEvent::EnvRequest(channel, name, value, _) => {
|
|
self._channel_env_request(channel, name, value).await?;
|
|
}
|
|
|
|
ServerHandlerEvent::X11Request(channel, request, _) => {
|
|
self._channel_x11_request(channel, request).await?;
|
|
}
|
|
|
|
ServerHandlerEvent::TcpIpForward(address, port, reply) => {
|
|
self._tcpip_forward(address, port).await?;
|
|
let _ = reply.send(true);
|
|
}
|
|
|
|
ServerHandlerEvent::CancelTcpIpForward(address, port, reply) => {
|
|
self._cancel_tcpip_forward(address, port).await?;
|
|
let _ = reply.send(true);
|
|
}
|
|
|
|
ServerHandlerEvent::Disconnect => (),
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn handle_session_control(&mut self, command: SessionHandleCommand) -> Result<()> {
|
|
match command {
|
|
SessionHandleCommand::Close => {
|
|
let _ = self.emit_service_message("Session closed by admin").await;
|
|
info!("Session closed by admin");
|
|
self.request_disconnect().await;
|
|
self.disconnect_server().await;
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn handle_remote_event(&mut self, event: RCEvent) -> Result<()> {
|
|
match event {
|
|
RCEvent::State(state) => {
|
|
self.rc_state = state;
|
|
match &self.rc_state {
|
|
RCState::Connected => {
|
|
self.service_output.hide_progress().await;
|
|
self.service_output.emit_output(Bytes::from(format!(
|
|
"{}{}\r\n",
|
|
ERASE_PROGRESS_SPINNER,
|
|
Colour::Black
|
|
.on(Colour::Green)
|
|
.paint(" ✓ Warpgate connected ")
|
|
)));
|
|
}
|
|
RCState::Disconnected => {
|
|
self.service_output.hide_progress().await;
|
|
self.disconnect_server().await;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
RCEvent::ConnectionError(error) => {
|
|
self.service_output.hide_progress().await;
|
|
|
|
match error {
|
|
ConnectionError::HostKeyMismatch {
|
|
received_key_type,
|
|
received_key_base64,
|
|
known_key_type,
|
|
known_key_base64,
|
|
} => {
|
|
let msg = format!(
|
|
concat!(
|
|
"Host key doesn't match the stored one.\n",
|
|
"Stored key ({}): {}\n",
|
|
"Received key ({}): {}",
|
|
),
|
|
known_key_type,
|
|
known_key_base64,
|
|
received_key_type,
|
|
received_key_base64
|
|
);
|
|
self.emit_service_message(&msg).await?;
|
|
self.emit_service_message(
|
|
"If you know that the key is correct (e.g. it has been changed),",
|
|
)
|
|
.await?;
|
|
self.emit_service_message(
|
|
"you can remove the old key in the Warpgate management UI and try again",
|
|
)
|
|
.await?;
|
|
}
|
|
error => {
|
|
self.service_output.emit_output(Bytes::from(format!(
|
|
"{}{} {}\r\n",
|
|
ERASE_PROGRESS_SPINNER,
|
|
Colour::Black.on(Colour::Red).paint(" Connection failed "),
|
|
error
|
|
)));
|
|
}
|
|
}
|
|
}
|
|
RCEvent::Error(e) => {
|
|
self.service_output.hide_progress().await;
|
|
let _ = self.emit_service_message(&format!("Error: {e}")).await;
|
|
self.disconnect_server().await;
|
|
}
|
|
RCEvent::Output(channel, data) => {
|
|
if let Some(recorder) = self.channel_recorders.get_mut(&channel) {
|
|
if let Err(error) = recorder
|
|
.write(TerminalRecordingStreamId::Output, &data)
|
|
.await
|
|
{
|
|
error!(%channel, ?error, "Failed to record terminal data");
|
|
self.channel_recorders.remove(&channel);
|
|
}
|
|
}
|
|
|
|
if let Some(recorder) = self.traffic_connection_recorders.get_mut(&channel) {
|
|
if let Err(error) = recorder.write_rx(&data).await {
|
|
error!(%channel, ?error, "Failed to record traffic data");
|
|
self.traffic_connection_recorders.remove(&channel);
|
|
}
|
|
}
|
|
|
|
let server_channel_id = self.map_channel_reverse(&channel)?;
|
|
if let Some(session) = self.session_handle.as_mut() {
|
|
let _ = session
|
|
.data(server_channel_id.0, CryptoVec::from_slice(&data))
|
|
.await;
|
|
}
|
|
}
|
|
RCEvent::Success(channel) => {
|
|
let server_channel_id = self.map_channel_reverse(&channel)?;
|
|
self.maybe_with_session(|handle| async move {
|
|
handle
|
|
.channel_success(server_channel_id.0)
|
|
.await
|
|
.context("failed to send data")
|
|
})
|
|
.await?;
|
|
}
|
|
RCEvent::ChannelFailure(channel) => {
|
|
let server_channel_id = self.map_channel_reverse(&channel)?;
|
|
self.maybe_with_session(|handle| async move {
|
|
handle
|
|
.channel_failure(server_channel_id.0)
|
|
.await
|
|
.context("failed to send data")
|
|
})
|
|
.await?;
|
|
}
|
|
RCEvent::Close(channel) => {
|
|
let server_channel_id = self.map_channel_reverse(&channel)?;
|
|
let _ = self
|
|
.maybe_with_session(|handle| async move {
|
|
handle
|
|
.close(server_channel_id.0)
|
|
.await
|
|
.context("failed to close ch")
|
|
})
|
|
.await;
|
|
}
|
|
RCEvent::Eof(channel) => {
|
|
let server_channel_id = self.map_channel_reverse(&channel)?;
|
|
self.maybe_with_session(|handle| async move {
|
|
handle
|
|
.eof(server_channel_id.0)
|
|
.await
|
|
.context("failed to send eof")
|
|
})
|
|
.await?;
|
|
}
|
|
RCEvent::ExitStatus(channel, code) => {
|
|
let server_channel_id = self.map_channel_reverse(&channel)?;
|
|
self.maybe_with_session(|handle| async move {
|
|
handle
|
|
.exit_status_request(server_channel_id.0, code)
|
|
.await
|
|
.context("failed to send exit status")
|
|
})
|
|
.await?;
|
|
}
|
|
RCEvent::ExitSignal {
|
|
channel,
|
|
signal_name,
|
|
core_dumped,
|
|
error_message,
|
|
lang_tag,
|
|
} => {
|
|
let server_channel_id = self.map_channel_reverse(&channel)?;
|
|
self.maybe_with_session(|handle| async move {
|
|
handle
|
|
.exit_signal_request(
|
|
server_channel_id.0,
|
|
signal_name,
|
|
core_dumped,
|
|
error_message,
|
|
lang_tag,
|
|
)
|
|
.await
|
|
.context("failed to send exit status")?;
|
|
Ok(())
|
|
})
|
|
.await?;
|
|
}
|
|
RCEvent::Done => {}
|
|
RCEvent::ExtendedData { channel, data, ext } => {
|
|
if let Some(recorder) = self.channel_recorders.get_mut(&channel) {
|
|
if let Err(error) = recorder
|
|
.write(TerminalRecordingStreamId::Error, &data)
|
|
.await
|
|
{
|
|
error!(%channel, ?error, "Failed to record session data");
|
|
self.channel_recorders.remove(&channel);
|
|
}
|
|
}
|
|
let server_channel_id = self.map_channel_reverse(&channel)?;
|
|
self.maybe_with_session(|handle| async move {
|
|
handle
|
|
.extended_data(server_channel_id.0, ext, CryptoVec::from_slice(&data))
|
|
.await
|
|
.map_err(|_| ())
|
|
.context("failed to send extended data")?;
|
|
Ok(())
|
|
})
|
|
.await?;
|
|
}
|
|
RCEvent::HostKeyReceived(key) => {
|
|
self.emit_service_message(&format!(
|
|
"Host key ({}): {}",
|
|
key.name(),
|
|
key.public_key_base64()
|
|
))
|
|
.await?;
|
|
}
|
|
RCEvent::HostKeyUnknown(key, reply) => {
|
|
self.handle_unknown_host_key(key, reply).await?;
|
|
}
|
|
RCEvent::ForwardedTcpIp(id, params) => {
|
|
if let Some(session) = &mut self.session_handle {
|
|
let server_channel = session
|
|
.channel_open_forwarded_tcpip(
|
|
params.connected_address,
|
|
params.connected_port,
|
|
params.originator_address.clone(),
|
|
params.originator_port,
|
|
)
|
|
.await?;
|
|
|
|
self.channel_map
|
|
.insert(ServerChannelId(server_channel.id()), id);
|
|
self.all_channels.push(id);
|
|
|
|
let recorder = self
|
|
.traffic_recorder_for(
|
|
¶ms.originator_address,
|
|
params.originator_port,
|
|
"forwarded-tcpip",
|
|
)
|
|
.await;
|
|
if let Some(recorder) = recorder {
|
|
#[allow(clippy::unwrap_used)]
|
|
let mut recorder = recorder.connection(TrafficConnectionParams {
|
|
dst_addr: Ipv4Addr::from_str("2.2.2.2").unwrap(),
|
|
dst_port: params.connected_port as u16,
|
|
src_addr: Ipv4Addr::from_str("1.1.1.1").unwrap(),
|
|
src_port: params.originator_port as u16,
|
|
});
|
|
if let Err(error) = recorder.write_connection_setup().await {
|
|
error!(channel=%id, ?error, "Failed to record connection setup");
|
|
}
|
|
self.traffic_connection_recorders.insert(id, recorder);
|
|
}
|
|
}
|
|
}
|
|
RCEvent::X11(id, originator_address, originator_port) => {
|
|
if let Some(session) = &mut self.session_handle {
|
|
let server_channel = session
|
|
.channel_open_x11(originator_address, originator_port)
|
|
.await?;
|
|
|
|
self.channel_map
|
|
.insert(ServerChannelId(server_channel.id()), id);
|
|
self.all_channels.push(id);
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_unknown_host_key(
|
|
&mut self,
|
|
key: PublicKey,
|
|
reply: oneshot::Sender<bool>,
|
|
) -> Result<()> {
|
|
self.service_output.hide_progress().await;
|
|
|
|
let mode = self
|
|
.services
|
|
.config
|
|
.lock()
|
|
.await
|
|
.store
|
|
.ssh
|
|
.host_key_verification;
|
|
|
|
if mode == SshHostKeyVerificationMode::AutoAccept {
|
|
let _ = reply.send(true);
|
|
info!("Accepted untrusted host key (auto-accept is enabled)");
|
|
return Ok(());
|
|
}
|
|
|
|
if mode == SshHostKeyVerificationMode::AutoReject {
|
|
let _ = reply.send(false);
|
|
info!("Rejected untrusted host key (auto-reject is enabled)");
|
|
return Ok(());
|
|
}
|
|
|
|
if self.pty_channels.is_empty() {
|
|
warn!("Target host key is not trusted, but there is no active PTY channel to show the trust prompt on.");
|
|
warn!(
|
|
"Connect to this target with an interactive session once to accept the host key."
|
|
);
|
|
self.request_disconnect().await;
|
|
anyhow::bail!("No PTY channel to show an interactive prompt on")
|
|
}
|
|
|
|
self.emit_service_message(&format!(
|
|
"There is no trusted {} key for this host.",
|
|
key.name()
|
|
))
|
|
.await?;
|
|
self.emit_service_message("Trust this key? (y/n)").await?;
|
|
|
|
let mut sub = self
|
|
.hub
|
|
.subscribe(|e| matches!(e, Event::ConsoleInput(_)))
|
|
.await;
|
|
|
|
let mut service_output = self.service_output.clone();
|
|
tokio::spawn(async move {
|
|
loop {
|
|
match sub.recv().await {
|
|
Some(Event::ConsoleInput(data)) => {
|
|
if data == "y".as_bytes() {
|
|
let _ = reply.send(true);
|
|
break;
|
|
} else if data == "n".as_bytes() {
|
|
let _ = reply.send(false);
|
|
break;
|
|
}
|
|
}
|
|
None => break,
|
|
_ => (),
|
|
}
|
|
}
|
|
service_output.show_progress();
|
|
});
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn maybe_with_session<'a, FN, FT, R>(&'a mut self, f: FN) -> Result<Option<R>>
|
|
where
|
|
FN: FnOnce(&'a mut russh::server::Handle) -> FT + 'a,
|
|
FT: futures::Future<Output = Result<R>>,
|
|
{
|
|
if let Some(handle) = &mut self.session_handle {
|
|
return Ok(Some(f(handle).await?));
|
|
}
|
|
Ok(None)
|
|
}
|
|
|
|
async fn _channel_open_direct_tcpip(
|
|
&mut self,
|
|
channel: ServerChannelId,
|
|
params: DirectTCPIPParams,
|
|
) -> Result<bool> {
|
|
let uuid = Uuid::new_v4();
|
|
self.channel_map.insert(channel, uuid);
|
|
|
|
info!(%channel, "Opening direct TCP/IP channel from {}:{} to {}:{}", params.originator_address, params.originator_port, params.host_to_connect, params.port_to_connect);
|
|
|
|
let _ = self.maybe_connect_remote().await;
|
|
|
|
match self
|
|
.send_command_and_wait(RCCommand::Channel(
|
|
uuid,
|
|
ChannelOperation::OpenDirectTCPIP(params.clone()),
|
|
))
|
|
.await
|
|
{
|
|
Ok(()) => {
|
|
self.all_channels.push(uuid);
|
|
|
|
let recorder = self
|
|
.traffic_recorder_for(
|
|
¶ms.host_to_connect,
|
|
params.port_to_connect,
|
|
"direct-tcpip",
|
|
)
|
|
.await;
|
|
if let Some(recorder) = recorder {
|
|
#[allow(clippy::unwrap_used)]
|
|
let mut recorder = recorder.connection(TrafficConnectionParams {
|
|
dst_addr: Ipv4Addr::from_str("2.2.2.2").unwrap(),
|
|
dst_port: params.port_to_connect as u16,
|
|
src_addr: Ipv4Addr::from_str("1.1.1.1").unwrap(),
|
|
src_port: params.originator_port as u16,
|
|
});
|
|
if let Err(error) = recorder.write_connection_setup().await {
|
|
error!(%channel, ?error, "Failed to record connection setup");
|
|
}
|
|
self.traffic_connection_recorders.insert(uuid, recorder);
|
|
}
|
|
|
|
Ok(true)
|
|
}
|
|
Err(SshClientError::ChannelFailure) => Ok(false),
|
|
Err(x) => Err(x.into()),
|
|
}
|
|
}
|
|
|
|
async fn _window_change_request(
|
|
&mut self,
|
|
server_channel_id: ServerChannelId,
|
|
request: PtyRequest,
|
|
) -> Result<()> {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
self.channel_pty_size_map
|
|
.insert(channel_id, request.clone());
|
|
if let Some(recorder) = self.channel_recorders.get_mut(&channel_id) {
|
|
if let Err(error) = recorder
|
|
.write_pty_resize(request.col_width, request.row_height)
|
|
.await
|
|
{
|
|
error!(%channel_id, ?error, "Failed to record terminal data");
|
|
self.channel_recorders.remove(&channel_id);
|
|
}
|
|
}
|
|
self.send_command_and_wait(RCCommand::Channel(
|
|
channel_id,
|
|
ChannelOperation::ResizePty(request),
|
|
))
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn _channel_exec_request(
|
|
&mut self,
|
|
server_channel_id: ServerChannelId,
|
|
data: Bytes,
|
|
) -> Result<()> {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
match std::str::from_utf8(&data) {
|
|
Err(e) => {
|
|
error!(channel=%channel_id, ?data, "Requested exec - invalid UTF-8");
|
|
anyhow::bail!(e)
|
|
}
|
|
Ok::<&str, _>(command) => {
|
|
debug!(channel=%channel_id, %command, "Requested exec");
|
|
let _ = self.maybe_connect_remote().await;
|
|
let _ = self.send_command(RCCommand::Channel(
|
|
channel_id,
|
|
ChannelOperation::RequestExec(command.to_string()),
|
|
));
|
|
}
|
|
}
|
|
|
|
self.start_terminal_recording(channel_id, format!("exec-channel-{}", server_channel_id.0))
|
|
.await;
|
|
Ok(())
|
|
}
|
|
|
|
async fn start_terminal_recording(&mut self, channel_id: Uuid, name: String) {
|
|
match async {
|
|
let mut recorder = self
|
|
.services
|
|
.recordings
|
|
.lock()
|
|
.await
|
|
.start::<TerminalRecorder>(&self.id, name)
|
|
.await?;
|
|
if let Some(request) = self.channel_pty_size_map.get(&channel_id) {
|
|
recorder
|
|
.write_pty_resize(request.col_width, request.row_height)
|
|
.await?;
|
|
}
|
|
Ok::<_, recordings::Error>(recorder)
|
|
}
|
|
.await
|
|
{
|
|
Ok(recorder) => {
|
|
self.channel_recorders.insert(channel_id, recorder);
|
|
}
|
|
Err(error) => match error {
|
|
recordings::Error::Disabled => (),
|
|
error => error!(channel=%channel_id, ?error, "Failed to start recording"),
|
|
},
|
|
}
|
|
}
|
|
|
|
async fn _channel_x11_request(
|
|
&mut self,
|
|
server_channel_id: ServerChannelId,
|
|
request: X11Request,
|
|
) -> Result<()> {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
debug!(channel=%channel_id, "Requested X11");
|
|
let _ = self.maybe_connect_remote().await;
|
|
self.send_command_and_wait(RCCommand::Channel(
|
|
channel_id,
|
|
ChannelOperation::RequestX11(request),
|
|
))
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn _channel_env_request(
|
|
&mut self,
|
|
server_channel_id: ServerChannelId,
|
|
name: String,
|
|
value: String,
|
|
) -> Result<()> {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
debug!(channel=%channel_id, %name, %value, "Environment");
|
|
self.send_command_and_wait(RCCommand::Channel(
|
|
channel_id,
|
|
ChannelOperation::RequestEnv(name, value),
|
|
))
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn traffic_recorder_for(
|
|
&mut self,
|
|
host: &str,
|
|
port: u32,
|
|
tag: &str,
|
|
) -> Option<&mut TrafficRecorder> {
|
|
let host = host.to_owned();
|
|
if let Vacant(e) = self.traffic_recorders.entry((host.clone(), port)) {
|
|
match self
|
|
.services
|
|
.recordings
|
|
.lock()
|
|
.await
|
|
.start(&self.id, format!("{tag}-{host}-{port}"))
|
|
.await
|
|
{
|
|
Ok(recorder) => {
|
|
e.insert(recorder);
|
|
}
|
|
Err(error) => {
|
|
error!(%host, %port, ?error, "Failed to start recording");
|
|
}
|
|
}
|
|
}
|
|
self.traffic_recorders.get_mut(&(host, port))
|
|
}
|
|
|
|
pub async fn _channel_shell_request_nowait(
|
|
&mut self,
|
|
server_channel_id: ServerChannelId,
|
|
) -> Result<()> {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
let _ = self.maybe_connect_remote().await;
|
|
|
|
let _ = self.send_command(RCCommand::Channel(
|
|
channel_id,
|
|
ChannelOperation::RequestShell,
|
|
));
|
|
|
|
self.start_terminal_recording(channel_id, format!("shell-channel-{}", server_channel_id.0))
|
|
.await;
|
|
|
|
info!(%channel_id, "Opening shell");
|
|
|
|
let _ = self
|
|
.session_handle
|
|
.as_mut()
|
|
.context("Invalid session state")?
|
|
.channel_success(server_channel_id.0)
|
|
.await;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn _channel_shell_request_begin(
|
|
&mut self,
|
|
server_channel_id: ServerChannelId,
|
|
) -> Result<()> {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
let _ = self.maybe_connect_remote().await;
|
|
self.send_command_and_wait(RCCommand::Channel(
|
|
channel_id,
|
|
ChannelOperation::RequestShell,
|
|
))
|
|
.await
|
|
.map_err(anyhow::Error::from)
|
|
}
|
|
|
|
pub async fn _channel_shell_request_finish(
|
|
&mut self,
|
|
server_channel_id: ServerChannelId,
|
|
) -> Result<()> {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
self.start_terminal_recording(channel_id, format!("shell-channel-{}", server_channel_id.0))
|
|
.await;
|
|
|
|
info!(%channel_id, "Opening shell");
|
|
let session = self
|
|
.session_handle
|
|
.clone()
|
|
.context("Invalid session state")?;
|
|
tokio::spawn(async move { session.channel_success(server_channel_id.0).await });
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn _channel_subsystem_request(
|
|
&mut self,
|
|
server_channel_id: ServerChannelId,
|
|
name: String,
|
|
) -> Result<()> {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
info!(channel=%channel_id, "Requesting subsystem {}", &name);
|
|
let _ = self.maybe_connect_remote().await;
|
|
self.send_command_and_wait(RCCommand::Channel(
|
|
channel_id,
|
|
ChannelOperation::RequestSubsystem(name),
|
|
))
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn _data(&mut self, server_channel_id: ServerChannelId, data: Bytes) -> Result<()> {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
debug!(channel=%server_channel_id.0, ?data, "Data");
|
|
if self.rc_state == RCState::Connecting && data.first() == Some(&3) {
|
|
info!(channel=%channel_id, "User requested connection abort (Ctrl-C)");
|
|
self.request_disconnect().await;
|
|
return Ok(());
|
|
}
|
|
|
|
if let Some(recorder) = self.channel_recorders.get_mut(&channel_id) {
|
|
if let Err(error) = recorder
|
|
.write(TerminalRecordingStreamId::Input, &data)
|
|
.await
|
|
{
|
|
error!(channel=%channel_id, ?error, "Failed to record terminal data");
|
|
self.channel_recorders.remove(&channel_id);
|
|
}
|
|
}
|
|
|
|
if let Some(recorder) = self.traffic_connection_recorders.get_mut(&channel_id) {
|
|
if let Err(error) = recorder.write_tx(&data).await {
|
|
error!(channel=%channel_id, ?error, "Failed to record traffic data");
|
|
self.traffic_connection_recorders.remove(&channel_id);
|
|
}
|
|
}
|
|
|
|
if self.pty_channels.contains(&channel_id) {
|
|
let _ = self
|
|
.event_sender
|
|
.send_once(Event::ConsoleInput(data.clone()))
|
|
.await;
|
|
}
|
|
|
|
let _ = self.send_command(RCCommand::Channel(channel_id, ChannelOperation::Data(data)));
|
|
Ok(())
|
|
}
|
|
|
|
async fn _extended_data(
|
|
&mut self,
|
|
server_channel_id: ServerChannelId,
|
|
code: u32,
|
|
data: Bytes,
|
|
) -> Result<()> {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
debug!(channel=%server_channel_id.0, ?data, "Data");
|
|
let _ = self.send_command_and_wait(RCCommand::Channel(
|
|
channel_id,
|
|
ChannelOperation::ExtendedData { ext: code, data },
|
|
));
|
|
Ok(())
|
|
}
|
|
|
|
async fn _tcpip_forward(&mut self, address: String, port: u32) -> Result<()> {
|
|
info!(%address, %port, "Remote port forwarding requested");
|
|
let _ = self.maybe_connect_remote().await;
|
|
self.send_command_and_wait(RCCommand::ForwardTCPIP(address, port))
|
|
.await
|
|
.map_err(anyhow::Error::from)
|
|
}
|
|
|
|
pub async fn _cancel_tcpip_forward(&mut self, address: String, port: u32) -> Result<()> {
|
|
info!(%address, %port, "Remote port forwarding cancelled");
|
|
self.send_command_and_wait(RCCommand::CancelTCPIPForward(address, port))
|
|
.await
|
|
.map_err(anyhow::Error::from)
|
|
}
|
|
|
|
async fn _auth_publickey(
|
|
&mut self,
|
|
ssh_username: Secret<String>,
|
|
key: PublicKey,
|
|
) -> russh::server::Auth {
|
|
let selector: AuthSelector = ssh_username.expose_secret().into();
|
|
|
|
info!(
|
|
"Public key auth as {:?} with key {}",
|
|
selector,
|
|
key.public_key_base64()
|
|
);
|
|
|
|
match self
|
|
.try_auth(
|
|
&selector,
|
|
Some(AuthCredential::PublicKey {
|
|
kind: key.name().to_string(),
|
|
public_key_bytes: Bytes::from(key.public_key_bytes()),
|
|
}),
|
|
)
|
|
.await
|
|
{
|
|
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
|
Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
|
|
proceed_with_methods: Some(MethodSet::all()),
|
|
},
|
|
Ok(AuthResult::Need(kinds)) => russh::server::Auth::Reject {
|
|
proceed_with_methods: Some(self.get_remaining_auth_methods(kinds)),
|
|
},
|
|
Err(error) => {
|
|
error!(?error, "Failed to verify credentials");
|
|
russh::server::Auth::Reject {
|
|
proceed_with_methods: None,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn _auth_password(
|
|
&mut self,
|
|
ssh_username: Secret<String>,
|
|
password: Secret<String>,
|
|
) -> russh::server::Auth {
|
|
let selector: AuthSelector = ssh_username.expose_secret().into();
|
|
info!("Password auth as {:?}", selector);
|
|
|
|
match self
|
|
.try_auth(&selector, Some(AuthCredential::Password(password)))
|
|
.await
|
|
{
|
|
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
|
Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
|
|
proceed_with_methods: None,
|
|
},
|
|
Ok(AuthResult::Need(_)) => russh::server::Auth::Reject {
|
|
proceed_with_methods: None,
|
|
},
|
|
Err(error) => {
|
|
error!(?error, "Failed to verify credentials");
|
|
russh::server::Auth::Reject {
|
|
proceed_with_methods: None,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn _auth_keyboard_interactive(
|
|
&mut self,
|
|
ssh_username: Secret<String>,
|
|
response: Option<Secret<String>>,
|
|
) -> russh::server::Auth {
|
|
let selector: AuthSelector = ssh_username.expose_secret().into();
|
|
info!("Keyboard-interactive auth as {:?}", selector);
|
|
|
|
let cred;
|
|
match &mut self.keyboard_interactive_state {
|
|
KeyboardInteractiveState::None => {
|
|
cred = None;
|
|
}
|
|
KeyboardInteractiveState::OtpRequested => {
|
|
cred = response.map(AuthCredential::Otp);
|
|
}
|
|
KeyboardInteractiveState::WebAuthRequested(event) => {
|
|
cred = None;
|
|
let _ = event.recv().await;
|
|
// the auth state has been updated by now
|
|
}
|
|
}
|
|
|
|
self.keyboard_interactive_state = KeyboardInteractiveState::None;
|
|
|
|
match self.try_auth(&selector, cred).await {
|
|
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
|
Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
|
|
proceed_with_methods: None,
|
|
},
|
|
Ok(AuthResult::Need(kinds)) => {
|
|
if kinds.contains(&CredentialKind::Totp) {
|
|
self.keyboard_interactive_state = KeyboardInteractiveState::OtpRequested;
|
|
russh::server::Auth::Partial {
|
|
name: Cow::Borrowed("Two-factor authentication"),
|
|
instructions: Cow::Borrowed(""),
|
|
prompts: Cow::Owned(vec![(Cow::Borrowed("One-time password: "), true)]),
|
|
}
|
|
} else if kinds.contains(&CredentialKind::WebUserApproval) {
|
|
let Some(auth_state) = self.auth_state.as_ref() else {
|
|
return russh::server::Auth::Reject { proceed_with_methods: None};
|
|
};
|
|
let auth_state_id = *auth_state.lock().await.id();
|
|
let event = self
|
|
.services
|
|
.auth_state_store
|
|
.lock()
|
|
.await
|
|
.subscribe(auth_state_id);
|
|
self.keyboard_interactive_state =
|
|
KeyboardInteractiveState::WebAuthRequested(event);
|
|
|
|
let mut login_url = match self
|
|
.services
|
|
.config
|
|
.lock()
|
|
.await
|
|
.construct_external_url(None)
|
|
{
|
|
Ok(url) => url,
|
|
Err(error) => {
|
|
error!(?error, "Failed to construct external URL");
|
|
return russh::server::Auth::Reject {
|
|
proceed_with_methods: None,
|
|
};
|
|
}
|
|
};
|
|
|
|
login_url.set_path("@warpgate");
|
|
login_url.set_fragment(Some(&format!("/login/{auth_state_id}")));
|
|
|
|
russh::server::Auth::Partial {
|
|
name: Cow::Owned(format!(
|
|
concat!(
|
|
"----------------------------------------------------------------\n",
|
|
"Warpgate authentication: please open {} in your browser\n",
|
|
"----------------------------------------------------------------\n"
|
|
),
|
|
login_url
|
|
)),
|
|
instructions: Cow::Borrowed(""),
|
|
prompts: Cow::Owned(vec![(Cow::Borrowed("Press Enter when done: "), true)]),
|
|
}
|
|
} else {
|
|
russh::server::Auth::Reject {
|
|
proceed_with_methods: None,
|
|
}
|
|
}
|
|
}
|
|
Err(error) => {
|
|
error!(?error, "Failed to verify credentials");
|
|
russh::server::Auth::Reject {
|
|
proceed_with_methods: None,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn get_remaining_auth_methods(&self, kinds: HashSet<CredentialKind>) -> MethodSet {
|
|
let mut m = MethodSet::empty();
|
|
for kind in kinds {
|
|
match kind {
|
|
CredentialKind::Password => m.insert(MethodSet::PASSWORD),
|
|
CredentialKind::Totp => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
|
|
CredentialKind::WebUserApproval => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
|
|
CredentialKind::PublicKey => m.insert(MethodSet::PUBLICKEY),
|
|
CredentialKind::Sso => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
|
|
}
|
|
}
|
|
m
|
|
}
|
|
|
|
async fn try_auth(
|
|
&mut self,
|
|
selector: &AuthSelector,
|
|
credential: Option<AuthCredential>,
|
|
) -> Result<AuthResult> {
|
|
match selector {
|
|
AuthSelector::User {
|
|
username,
|
|
target_name,
|
|
} => {
|
|
let cp = self.services.config_provider.clone();
|
|
|
|
let state_arc = self.get_auth_state(username).await?;
|
|
let mut state = state_arc.lock().await;
|
|
|
|
if let Some(credential) = credential {
|
|
if cp
|
|
.lock()
|
|
.await
|
|
.validate_credential(username, &credential)
|
|
.await?
|
|
{
|
|
state.add_valid_credential(credential);
|
|
}
|
|
}
|
|
|
|
let user_auth_result = state.verify();
|
|
|
|
match user_auth_result {
|
|
AuthResult::Accepted { username } => {
|
|
self.services
|
|
.auth_state_store
|
|
.lock()
|
|
.await
|
|
.complete(state.id())
|
|
.await;
|
|
let target_auth_result = {
|
|
self.services
|
|
.config_provider
|
|
.lock()
|
|
.await
|
|
.authorize_target(&username, target_name)
|
|
.await?
|
|
};
|
|
if !target_auth_result {
|
|
warn!(
|
|
"Target {} not authorized for user {}",
|
|
target_name, username
|
|
);
|
|
return Ok(AuthResult::Rejected);
|
|
}
|
|
self._auth_accept(&username, target_name).await?;
|
|
Ok(AuthResult::Accepted { username })
|
|
}
|
|
x => Ok(x),
|
|
}
|
|
}
|
|
AuthSelector::Ticket { secret } => {
|
|
match authorize_ticket(&self.services.db, secret).await? {
|
|
Some(ticket) => {
|
|
info!("Authorized for {} with a ticket", ticket.target);
|
|
consume_ticket(&self.services.db, &ticket.id).await?;
|
|
self._auth_accept(&ticket.username, &ticket.target).await?;
|
|
Ok(AuthResult::Accepted {
|
|
username: ticket.username.clone(),
|
|
})
|
|
}
|
|
None => Ok(AuthResult::Rejected),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn _auth_accept(
|
|
&mut self,
|
|
username: &str,
|
|
target_name: &str,
|
|
) -> Result<(), WarpgateError> {
|
|
info!(%username, "Authenticated");
|
|
|
|
let _ = self
|
|
.server_handle
|
|
.lock()
|
|
.await
|
|
.set_username(username.to_string())
|
|
.await;
|
|
self.username = Some(username.to_string());
|
|
|
|
let target = {
|
|
self.services
|
|
.config_provider
|
|
.lock()
|
|
.await
|
|
.list_targets()
|
|
.await?
|
|
.iter()
|
|
.filter_map(|t| match t.options {
|
|
TargetOptions::Ssh(ref options) => Some((t, options)),
|
|
_ => None,
|
|
})
|
|
.find(|(t, _)| t.name == target_name)
|
|
.map(|(t, opt)| (t.clone(), opt.clone()))
|
|
};
|
|
|
|
let Some((target, mut ssh_options)) = target else {
|
|
self.target = TargetSelection::NotFound(target_name.to_string());
|
|
warn!("Selected target not found");
|
|
return Ok(());
|
|
};
|
|
|
|
// Forward username from the authenticated user to the target, if target has no username
|
|
if ssh_options.username.is_empty() {
|
|
ssh_options.username = username.to_string();
|
|
}
|
|
|
|
let _ = self.server_handle.lock().await.set_target(&target).await;
|
|
self.target = TargetSelection::Found(target, ssh_options);
|
|
Ok(())
|
|
}
|
|
|
|
async fn _channel_close(&mut self, server_channel_id: ServerChannelId) -> Result<()> {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
debug!(channel=%channel_id, "Closing channel");
|
|
self.send_command_and_wait(RCCommand::Channel(channel_id, ChannelOperation::Close))
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn _channel_eof(&mut self, server_channel_id: ServerChannelId) -> Result<()> {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
debug!(channel=%channel_id, "EOF");
|
|
let _ = self.send_command(RCCommand::Channel(channel_id, ChannelOperation::Eof));
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn _channel_signal(
|
|
&mut self,
|
|
server_channel_id: ServerChannelId,
|
|
signal: Sig,
|
|
) -> Result<()> {
|
|
let channel_id = self.map_channel(&server_channel_id)?;
|
|
debug!(channel=%channel_id, ?signal, "Signal");
|
|
self.send_command_and_wait(RCCommand::Channel(
|
|
channel_id,
|
|
ChannelOperation::Signal(signal),
|
|
))
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Queues `command` for the remote client without a reply channel.
///
/// # Errors
/// If the channel to the remote client is closed, the command is handed back
/// unchanged.
fn send_command(&mut self, command: RCCommand) -> Result<(), RCCommand> {
    self.rc_tx.send((command, None)).map_err(|rejected| {
        let (command, _no_reply) = rejected.0;
        command
    })
}
|
|
|
|
async fn send_command_and_wait(&mut self, command: RCCommand) -> Result<(), SshClientError> {
|
|
let (tx, rx) = oneshot::channel();
|
|
let mut cmd = match self.rc_tx.send((command, Some(tx))) {
|
|
Ok(_) => PendingCommand::Waiting(rx),
|
|
Err(_) => PendingCommand::Failed,
|
|
};
|
|
|
|
loop {
|
|
tokio::select! {
|
|
result = &mut cmd => {
|
|
return result
|
|
}
|
|
event = self.get_next_event() => {
|
|
match event {
|
|
Some(event) => {
|
|
self.handle_event(event).await.map_err(SshClientError::from)?
|
|
}
|
|
None => {Err(SshClientError::MpscError)?}
|
|
};
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Entry point for a disconnect initiated by the SSH client.
pub async fn _disconnect(&mut self) {
    debug!("Client disconnect requested");
    self.request_disconnect().await
}
|
|
|
|
async fn request_disconnect(&mut self) {
|
|
debug!("Disconnecting");
|
|
let _ = self.rc_abort_tx.send(());
|
|
if self.rc_state != RCState::NotInitialized && self.rc_state != RCState::Disconnected {
|
|
let _ = self.send_command(RCCommand::Disconnect);
|
|
}
|
|
}
|
|
|
|
async fn disconnect_server(&mut self) {
|
|
let all_channels = std::mem::take(&mut self.all_channels);
|
|
let channels = all_channels
|
|
.into_iter()
|
|
.map(|x| self.map_channel_reverse(&x))
|
|
.filter_map(|x| x.ok())
|
|
.collect::<Vec<_>>();
|
|
|
|
let _ = self
|
|
.maybe_with_session(|handle| async move {
|
|
for ch in channels {
|
|
let _ = handle.close(ch.0).await;
|
|
}
|
|
Ok(())
|
|
})
|
|
.await;
|
|
|
|
self.session_handle = None;
|
|
}
|
|
}
|
|
|
|
impl Drop for ServerSession {
|
|
fn drop(&mut self) {
|
|
let _ = self.rc_abort_tx.send(());
|
|
info!("Closed session");
|
|
debug!("Dropped");
|
|
}
|
|
}
|
|
|
|
pub enum PendingCommand {
|
|
Waiting(oneshot::Receiver<Result<(), SshClientError>>),
|
|
Failed,
|
|
}
|
|
|
|
impl Future for PendingCommand {
|
|
type Output = Result<(), SshClientError>;
|
|
|
|
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
|
|
match self.get_mut() {
|
|
PendingCommand::Waiting(ref mut rx) => match Pin::new(rx).poll(cx) {
|
|
Poll::Ready(result) => {
|
|
Poll::Ready(result.unwrap_or(Err(SshClientError::MpscError)))
|
|
}
|
|
Poll::Pending => Poll::Pending,
|
|
},
|
|
PendingCommand::Failed => Poll::Ready(Err(SshClientError::MpscError)),
|
|
}
|
|
}
|
|
}
|