From c6885f18c362d2df42163cabdcc2964f121daf34 Mon Sep 17 00:00:00 2001 From: Eugene Date: Mon, 8 Aug 2022 23:30:18 +0200 Subject: [PATCH] Out-of-band SSO (#245) --- .github/workflows/build.yml | 2 - Cargo.lock | 6 +- rust-toolchain | 1 + rust-toolchain.toml | 2 - warpgate-common/Cargo.toml | 1 + warpgate-common/src/auth/cred.rs | 4 + warpgate-common/src/auth/policy.rs | 89 ++++++--- warpgate-common/src/auth/state.rs | 61 +++--- warpgate-common/src/auth/store.rs | 90 ++++++--- warpgate-common/src/config/mod.rs | 24 ++- warpgate-common/src/config_providers/file.rs | 57 +++++- warpgate-common/src/config_providers/mod.rs | 6 +- warpgate-common/src/error.rs | 8 + warpgate-protocol-http/Cargo.toml | 1 + warpgate-protocol-http/src/api/auth.rs | 181 ++++++++++++++--- .../src/api/sso_provider_detail.rs | 39 ++-- .../src/api/sso_provider_list.rs | 11 +- warpgate-protocol-http/src/common.rs | 10 +- warpgate-protocol-mysql/src/session.rs | 31 +-- warpgate-protocol-ssh/Cargo.toml | 2 +- warpgate-protocol-ssh/src/server/mod.rs | 1 + warpgate-protocol-ssh/src/server/session.rs | 183 ++++++++++++++---- warpgate-web/src/gateway/App.svelte | 50 +++-- warpgate-web/src/gateway/Login.svelte | 48 +++-- warpgate-web/src/gateway/OutOfBandAuth.svelte | 65 +++++++ .../src/gateway/lib/openapi-schema.json | 134 ++++++++++++- warpgate/src/commands/check.rs | 3 + warpgate/src/config.rs | 1 - 28 files changed, 863 insertions(+), 248 deletions(-) create mode 100644 rust-toolchain delete mode 100644 rust-toolchain.toml create mode 100644 warpgate-web/src/gateway/OutOfBandAuth.svelte diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3756fe0..9f03cc1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -12,7 +12,6 @@ jobs: - uses: actions-rs/toolchain@v1 with: - toolchain: nightly target: x86_64-unknown-linux-gnu override: true @@ -42,7 +41,6 @@ jobs: uses: actions-rs/cargo@v1 with: command: build - toolchain: nightly use-cross: true args: --release --target x86_64-unknown-linux-gnu -Ztarget-applies-to-host diff --git a/Cargo.lock b/Cargo.lock index 98edba0..c944c07 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3143,9 +3143,9 @@ dependencies = [ [[package]] name = "russh" -version = "0.34.0-beta.7" +version = "0.34.0-beta.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f3bb72a66e32d52e0e258627d141d5c93b408e050f15033699caa836d064c7e" +checksum = "ccd8be93ee0b54a8a6b74c77ecef946185f0acbfb5234ea66666887621381e85" dependencies = [ "aes 0.8.1", "aes-gcm 0.10.1", @@ -4620,6 +4620,7 @@ dependencies = [ "bytes 1.2.1", "chrono", "data-encoding", + "futures", "humantime-serde", "lazy_static", "once_cell", @@ -4691,6 +4692,7 @@ version = "0.4.0" dependencies = [ "anyhow", "async-trait", + "chrono", "cookie", "data-encoding", "delegate", diff --git a/rust-toolchain b/rust-toolchain new file mode 100644 index 0000000..8350116 --- /dev/null +++ b/rust-toolchain @@ -0,0 +1 @@ +nightly-2022-08-01 diff --git a/rust-toolchain.toml b/rust-toolchain.toml deleted file mode 100644 index 718c034..0000000 --- a/rust-toolchain.toml +++ /dev/null @@ -1,2 +0,0 @@ -[toolchain] -channel = "nightly-2022-07-22" diff --git a/warpgate-common/Cargo.toml b/warpgate-common/Cargo.toml index bb17605..e4f5ef6 100644 --- a/warpgate-common/Cargo.toml +++ b/warpgate-common/Cargo.toml @@ -13,6 +13,7 @@ chrono = { version = "0.4", features = ["serde"] } data-encoding = "2.3" humantime-serde = "1.1" lazy_static = "1.4" +futures = "0.3" once_cell = "1.10" packet = "0.1" password-hash = "0.4" diff --git a/warpgate-common/src/auth/cred.rs b/warpgate-common/src/auth/cred.rs index 5e1badd..2548629 100644 --- a/warpgate-common/src/auth/cred.rs +++ b/warpgate-common/src/auth/cred.rs @@ -13,6 +13,8 @@ pub enum CredentialKind { Otp, #[serde(rename = "sso")] Sso, + #[serde(rename = "web")] + WebUserApproval, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -27,6 +29,7 @@ pub enum AuthCredential { provider: String, email: String, }, + WebUserApproval, } impl AuthCredential { @@ -36,6 +39,7 @@ impl AuthCredential { Self::PublicKey { .. } => CredentialKind::PublicKey, Self::Otp { .. } => CredentialKind::Otp, Self::Sso { .. } => CredentialKind::Sso, + Self::WebUserApproval => CredentialKind::WebUserApproval, } } } diff --git a/warpgate-common/src/auth/policy.rs b/warpgate-common/src/auth/policy.rs index d349a59..7d2df63 100644 --- a/warpgate-common/src/auth/policy.rs +++ b/warpgate-common/src/auth/policy.rs @@ -1,12 +1,10 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use super::{AuthCredential, CredentialKind}; -use crate::UserRequireCredentialsPolicy; pub enum CredentialPolicyResponse { Ok, - NeedMoreCredentials, - Need(CredentialKind), + Need(HashSet), } pub trait CredentialPolicy { @@ -17,36 +15,71 @@ pub trait CredentialPolicy { ) -> CredentialPolicyResponse; } -impl CredentialPolicy for UserRequireCredentialsPolicy { +pub struct AnySingleCredentialPolicy { + pub supported_credential_types: HashSet, +} + +pub struct AllCredentialsPolicy { + pub required_credential_types: HashSet, + pub supported_credential_types: HashSet, +} + +pub struct PerProtocolCredentialPolicy { + pub protocols: HashMap<&'static str, Box>, + pub default: Box, +} + +impl CredentialPolicy for AnySingleCredentialPolicy { fn is_sufficient( &self, - protocol: &str, + _protocol: &str, valid_credentials: &[AuthCredential], ) -> CredentialPolicyResponse { - let required_kinds = match protocol { - "SSH" => &self.ssh, - "HTTP" => &self.http, - "MySQL" => &self.mysql, - _ => unreachable!(), - }; - if let Some(required_kinds) = required_kinds { - let mut remaining_required_kinds = HashSet::::new(); - remaining_required_kinds.extend(required_kinds); - for kind in required_kinds { - if valid_credentials.iter().any(|x| x.kind() == *kind) { - remaining_required_kinds.remove(kind); - } - } - - if let Some(kind) = remaining_required_kinds.into_iter().next() { - CredentialPolicyResponse::Need(kind) - } else { - CredentialPolicyResponse::Ok - } - } else if valid_credentials.is_empty() { - CredentialPolicyResponse::NeedMoreCredentials + if valid_credentials.is_empty() { + CredentialPolicyResponse::Need( + self.supported_credential_types + .clone() + .into_iter() + .collect(), + ) } else { CredentialPolicyResponse::Ok } } } + +impl CredentialPolicy for AllCredentialsPolicy { + fn is_sufficient( + &self, + _protocol: &str, + valid_credentials: &[AuthCredential], + ) -> CredentialPolicyResponse { + let valid_credential_types: HashSet = + valid_credentials.iter().map(|x| x.kind()).collect(); + + if valid_credential_types.is_superset(&self.required_credential_types) { + CredentialPolicyResponse::Ok + } else { + CredentialPolicyResponse::Need( + self.required_credential_types + .difference(&valid_credential_types) + .cloned() + .collect(), + ) + } + } +} + +impl CredentialPolicy for PerProtocolCredentialPolicy { + fn is_sufficient( + &self, + protocol: &str, + valid_credentials: &[AuthCredential], + ) -> CredentialPolicyResponse { + if let Some(policy) = self.protocols.get(protocol) { + policy.is_sufficient(protocol, valid_credentials) + } else { + self.default.is_sufficient(protocol, valid_credentials) + } + } +} diff --git a/warpgate-common/src/auth/state.rs b/warpgate-common/src/auth/state.rs index 2ee9ea4..41617e9 100644 --- a/warpgate-common/src/auth/state.rs +++ b/warpgate-common/src/auth/state.rs @@ -1,71 +1,66 @@ -use std::time::{Duration, Instant}; - -use once_cell::sync::Lazy; -use tracing::warn; +use uuid::Uuid; use super::{AuthCredential, CredentialPolicy, CredentialPolicyResponse}; use crate::AuthResult; -#[allow(clippy::unwrap_used)] -pub static TIMEOUT: Lazy = Lazy::new(|| Duration::from_secs(60 * 10)); - pub struct AuthState { + id: Uuid, username: String, protocol: String, - policy: Option>, + force_rejected: bool, + policy: Box, valid_credentials: Vec, - started_at: Instant, } impl AuthState { - pub fn new( + pub(crate) fn new( + id: Uuid, username: String, protocol: String, - policy: Option>, + policy: Box, ) -> Self { Self { + id, username, protocol, + force_rejected: false, policy, valid_credentials: vec![], - started_at: Instant::now(), } } + pub fn id(&self) -> &Uuid { + &self.id + } + pub fn username(&self) -> &str { &self.username } + pub fn protocol(&self) -> &str { + &self.protocol + } + pub fn add_valid_credential(&mut self, credential: AuthCredential) { self.valid_credentials.push(credential); } - pub fn is_expired(&self) -> bool { - self.started_at.elapsed() > *TIMEOUT + pub fn reject(&mut self) { + self.force_rejected = true; } pub fn verify(&self) -> AuthResult { - if self.valid_credentials.is_empty() { - warn!( - username=%self.username, - "No matching valid credentials" - ); + if self.force_rejected { return AuthResult::Rejected; } - - if let Some(ref policy) = self.policy { - match policy.is_sufficient(&self.protocol, &self.valid_credentials[..]) { - CredentialPolicyResponse::Ok => {} - CredentialPolicyResponse::Need(kind) => { - return AuthResult::Need(kind); - } - CredentialPolicyResponse::NeedMoreCredentials => { - return AuthResult::Rejected; - } - } - } - AuthResult::Accepted { - username: self.username.clone(), + match self + .policy + .is_sufficient(&self.protocol, &self.valid_credentials[..]) + { + CredentialPolicyResponse::Ok => AuthResult::Accepted { + username: self.username.clone(), + }, + CredentialPolicyResponse::Need(kinds) => AuthResult::Need(kinds), } } } diff --git a/warpgate-common/src/auth/store.rs b/warpgate-common/src/auth/store.rs index 9d9a143..d3a8b8c 100644 --- a/warpgate-common/src/auth/store.rs +++ b/warpgate-common/src/auth/store.rs @@ -1,15 +1,32 @@ use std::collections::HashMap; use std::sync::Arc; +use std::time::{Duration, Instant}; -use tokio::sync::Mutex; +use once_cell::sync::Lazy; +use tokio::sync::{broadcast, Mutex}; use uuid::Uuid; use super::AuthState; -use crate::{ConfigProvider, WarpgateError}; +use crate::{AuthResult, ConfigProvider, WarpgateError}; + +#[allow(clippy::unwrap_used)] +pub static TIMEOUT: Lazy = Lazy::new(|| Duration::from_secs(60 * 10)); + +struct AuthCompletionSignal { + sender: broadcast::Sender, + created_at: Instant, +} + +impl AuthCompletionSignal { + pub fn is_expired(&self) -> bool { + self.created_at.elapsed() > *TIMEOUT + } +} pub struct AuthStateStore { config_provider: Arc>, - store: HashMap, + store: HashMap>, Instant)>, + completion_signals: HashMap, } impl AuthStateStore { @@ -17,47 +34,66 @@ impl AuthStateStore { Self { store: HashMap::new(), config_provider, + completion_signals: HashMap::new(), } } - pub fn contains_key(&mut self, id: &Uuid) -> bool { + pub fn contains_key(&self, id: &Uuid) -> bool { self.store.contains_key(id) } - pub fn get_mut(&mut self, id: &Uuid) -> Option<&mut AuthState> { - self.store.get_mut(id) + pub fn get(&self, id: &Uuid) -> Option>> { + self.store.get(id).map(|x| x.0.clone()) } pub async fn create( &mut self, username: &str, protocol: &str, - ) -> Result<(Uuid, &mut AuthState), WarpgateError> { + ) -> Result<(Uuid, Arc>), WarpgateError> { let id = Uuid::new_v4(); - let state = AuthState::new( - username.to_string(), - protocol.to_string(), - self.config_provider - .lock() - .await - .get_credential_policy(username) - .await?, - ); - self.store.insert(id, state); + let Some(policy) = self.config_provider + .lock() + .await + .get_credential_policy(username) + .await? else { + return Err(WarpgateError::UserNotFound) + }; + + let state = AuthState::new(id, username.to_string(), protocol.to_string(), policy); + self.store + .insert(id, (Arc::new(Mutex::new(state)), Instant::now())); #[allow(clippy::unwrap_used)] - Ok((id, self.store.get_mut(&id).unwrap())) + Ok((id, self.get(&id).unwrap())) + } + + pub fn subscribe(&mut self, id: Uuid) -> broadcast::Receiver { + let signal = self.completion_signals.entry(id).or_insert_with(|| { + let (sender, _) = broadcast::channel(1); + AuthCompletionSignal { + sender, + created_at: Instant::now(), + } + }); + + signal.sender.subscribe() + } + + pub async fn complete(&mut self, id: &Uuid) { + let Some((state, _)) = self.store.get(id) else { + return + }; + if let Some(sig) = self.completion_signals.remove(id) { + let _ = sig.sender.send(state.lock().await.verify()); + } } pub async fn vacuum(&mut self) { - let mut to_remove = vec![]; - for (id, state) in self.store.iter() { - if state.is_expired() { - to_remove.push(*id); - } - } - for id in to_remove { - self.store.remove(&id); - } + self.store + .retain(|_, (_, started_at)| started_at.elapsed() < *TIMEOUT); + + self.completion_signals + .retain(|_, signal| !signal.is_expired()); } } diff --git a/warpgate-common/src/config/mod.rs b/warpgate-common/src/config/mod.rs index bb3bd87..82e6081 100644 --- a/warpgate-common/src/config/mod.rs +++ b/warpgate-common/src/config/mod.rs @@ -5,11 +5,12 @@ use std::time::Duration; use poem_openapi::{Enum, Object, Union}; use serde::{Deserialize, Serialize}; +use url::Url; use warpgate_sso::SsoProviderConfig; use crate::auth::CredentialKind; use crate::helpers::otp::OtpSecretKey; -use crate::{ListenEndpoint, Secret}; +use crate::{ListenEndpoint, Secret, WarpgateError}; const fn _default_true() -> bool { true @@ -426,3 +427,24 @@ pub struct WarpgateConfig { pub store: WarpgateConfigStore, pub paths_relative_to: PathBuf, } + +impl WarpgateConfig { + pub fn construct_external_url( + &self, + fallback_host: Option<&str>, + ) -> Result { + let ext_host = self.store.external_host.as_deref().or(fallback_host); + let Some(ext_host) = ext_host else { + return Err(WarpgateError::ExternalHostNotSet); + }; + let ext_port = self.store.http.listen.port(); + + let mut url = Url::parse(&format!("https://{ext_host}/"))?; + + if ext_port != 443 { + let _ = url.set_port(Some(ext_port)); + } + + Ok(url) + } +} diff --git a/warpgate-common/src/config_providers/file.rs b/warpgate-common/src/config_providers/file.rs index b5eb0dc..6a87a86 100644 --- a/warpgate-common/src/config_providers/file.rs +++ b/warpgate-common/src/config_providers/file.rs @@ -1,4 +1,4 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use async_trait::async_trait; @@ -11,7 +11,10 @@ use uuid::Uuid; use warpgate_db_entities::Ticket; use super::ConfigProvider; -use crate::auth::{AuthCredential, CredentialPolicy}; +use crate::auth::{ + AllCredentialsPolicy, AnySingleCredentialPolicy, AuthCredential, CredentialKind, + CredentialPolicy, PerProtocolCredentialPolicy, +}; use crate::helpers::hash::verify_password_hash; use crate::helpers::otp::verify_totp; use crate::{Target, User, UserAuthCredential, UserSnapshot, WarpgateConfig, WarpgateError}; @@ -78,9 +81,52 @@ impl ConfigProvider for FileConfigProvider { return Ok(None); }; - Ok(user - .require - .map(|r| Box::new(r) as Box)) + let supported_credential_types: HashSet = + user.credentials.iter().map(|x| x.kind()).collect(); + let default_policy = Box::new(AnySingleCredentialPolicy { + supported_credential_types: supported_credential_types.clone(), + }) as Box; + + if let Some(req) = user.require { + let mut policy = PerProtocolCredentialPolicy { + default: default_policy, + protocols: HashMap::new(), + }; + + if let Some(p) = req.http { + policy.protocols.insert( + "HTTP", + Box::new(AllCredentialsPolicy { + supported_credential_types: supported_credential_types.clone(), + required_credential_types: p.into_iter().collect(), + }), + ); + } + if let Some(p) = req.mysql { + policy.protocols.insert( + "MySQL", + Box::new(AllCredentialsPolicy { + supported_credential_types: supported_credential_types.clone(), + required_credential_types: p.into_iter().collect(), + }), + ); + } + if let Some(p) = req.ssh { + policy.protocols.insert( + "SSH", + Box::new(AllCredentialsPolicy { + supported_credential_types, + required_credential_types: p.into_iter().collect(), + }), + ); + } + + Ok(Some( + Box::new(policy) as Box + )) + } else { + Ok(Some(default_policy)) + } } async fn username_for_sso_credential( @@ -193,6 +239,7 @@ impl ConfigProvider for FileConfigProvider { } return Ok(false); } + _ => return Err(WarpgateError::InvalidCredentialType), } } diff --git a/warpgate-common/src/config_providers/mod.rs b/warpgate-common/src/config_providers/mod.rs index 3209d3f..f23d1c5 100644 --- a/warpgate-common/src/config_providers/mod.rs +++ b/warpgate-common/src/config_providers/mod.rs @@ -1,4 +1,5 @@ mod file; +use std::collections::HashSet; use std::sync::Arc; use async_trait::async_trait; @@ -12,11 +13,10 @@ use warpgate_db_entities::Ticket; use crate::auth::{AuthCredential, CredentialKind, CredentialPolicy}; use crate::{Secret, Target, UserSnapshot, WarpgateError}; -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum AuthResult { Accepted { username: String }, - Need(CredentialKind), - NeedMoreCredentials, + Need(HashSet), Rejected, } diff --git a/warpgate-common/src/error.rs b/warpgate-common/src/error.rs index f37cc95..1eeb26d 100644 --- a/warpgate-common/src/error.rs +++ b/warpgate-common/src/error.rs @@ -9,8 +9,16 @@ pub enum WarpgateError { DatabaseError(#[from] sea_orm::DbErr), #[error("ticket not found: {0}")] InvalidTicket(Uuid), + #[error("invalid credential type")] + InvalidCredentialType, #[error(transparent)] Other(Box), + #[error("user not found")] + UserNotFound, + #[error("failed to parse url: {0}")] + UrlParse(#[from] url::ParseError), + #[error("external_url config option is not set")] + ExternalHostNotSet, } impl ResponseError for WarpgateError { diff --git a/warpgate-protocol-http/Cargo.toml b/warpgate-protocol-http/Cargo.toml index 41293b1..c83819d 100644 --- a/warpgate-protocol-http/Cargo.toml +++ b/warpgate-protocol-http/Cargo.toml @@ -7,6 +7,7 @@ version = "0.4.0" [dependencies] anyhow = "1.0" async-trait = "0.1" +chrono = {version = "0.4", features = ["serde"]} cookie = "0.16" data-encoding = "2.3" delegate = "0.6" diff --git a/warpgate-protocol-http/src/api/auth.rs b/warpgate-protocol-http/src/api/auth.rs index 9a9e019..04d06a0 100644 --- a/warpgate-protocol-http/src/api/auth.rs +++ b/warpgate-protocol-http/src/api/auth.rs @@ -3,14 +3,18 @@ use std::sync::Arc; use poem::session::Session; use poem::web::Data; use poem::Request; +use poem_openapi::param::Path; use poem_openapi::payload::Json; use poem_openapi::{ApiResponse, Enum, Object, OpenApi}; use tokio::sync::Mutex; use tracing::*; -use warpgate_common::auth::{AuthCredential, CredentialKind}; +use uuid::Uuid; +use warpgate_common::auth::{AuthCredential, AuthState, CredentialKind}; use warpgate_common::{AuthResult, Secret, Services}; -use crate::common::{authorize_session, get_auth_state_for_request, SessionExt}; +use crate::common::{ + authorize_session, endpoint_auth, get_auth_state_for_request, SessionAuthorization, SessionExt, +}; use crate::session::SessionStore; pub struct Api; @@ -33,6 +37,8 @@ enum ApiAuthState { PasswordNeeded, OtpNeeded, SsoNeeded, + WebUserApprovalNeeded, + PublicKeyNeeded, Success, } @@ -58,6 +64,7 @@ enum LogoutResponse { #[derive(Object)] struct AuthStateResponseInternal { + pub protocol: String, pub state: ApiAuthState, } @@ -65,17 +72,22 @@ struct AuthStateResponseInternal { enum AuthStateResponse { #[oai(status = 200)] Ok(Json), + #[oai(status = 404)] + NotFound, } impl From for ApiAuthState { fn from(state: AuthResult) -> Self { match state { AuthResult::Rejected => ApiAuthState::Failed, - AuthResult::Need(CredentialKind::Password) => ApiAuthState::PasswordNeeded, - AuthResult::Need(CredentialKind::Otp) => ApiAuthState::OtpNeeded, - AuthResult::Need(CredentialKind::Sso) => ApiAuthState::SsoNeeded, - AuthResult::Need(CredentialKind::PublicKey) => ApiAuthState::Failed, - AuthResult::NeedMoreCredentials => ApiAuthState::Failed, + AuthResult::Need(kinds) => match kinds.iter().next() { + Some(CredentialKind::Password) => ApiAuthState::PasswordNeeded, + Some(CredentialKind::Otp) => ApiAuthState::OtpNeeded, + Some(CredentialKind::Sso) => ApiAuthState::SsoNeeded, + Some(CredentialKind::WebUserApproval) => ApiAuthState::WebUserApprovalNeeded, + Some(CredentialKind::PublicKey) => ApiAuthState::PublicKeyNeeded, + None => ApiAuthState::Failed, + }, AuthResult::Accepted { .. } => ApiAuthState::Success, } } @@ -92,8 +104,9 @@ impl Api { body: Json, ) -> poem::Result { let mut auth_state_store = services.auth_state_store.lock().await; - let state = + let state_arc = get_auth_state_for_request(&body.username, session, &mut auth_state_store).await?; + let mut state = state_arc.lock().await; let mut cp = services.config_provider.lock().await; @@ -107,6 +120,7 @@ impl Api { match state.verify() { AuthResult::Accepted { username } => { + auth_state_store.complete(state.id()).await; authorize_session(req, username).await?; Ok(LoginResponse::Success) } @@ -131,12 +145,14 @@ impl Api { let mut auth_state_store = services.auth_state_store.lock().await; - let Some(state) = state_id.and_then(|id| auth_state_store.get_mut(&id.0)) else { + let Some(state_arc) = state_id.and_then(|id| auth_state_store.get(&id.0)) else { return Ok(LoginResponse::Failure(Json(LoginFailureResponse { state: ApiAuthState::NotStarted, }))) }; + let mut state = state_arc.lock().await; + let mut cp = services.config_provider.lock().await; let otp_cred = AuthCredential::Otp(body.otp.clone().into()); @@ -146,6 +162,7 @@ impl Api { match state.verify() { AuthResult::Accepted { username } => { + auth_state_store.complete(state.id()).await; authorize_session(req, username).await?; Ok(LoginResponse::Success) } @@ -155,27 +172,6 @@ impl Api { } } - #[oai(path = "/auth/state", method = "get", operation_id = "getAuthState")] - async fn api_auth_state( - &self, - session: &Session, - services: Data<&Services>, - ) -> poem::Result { - let state_id = session.get_auth_state_id(); - - let mut auth_state_store = services.auth_state_store.lock().await; - - let Some(state) = state_id.and_then(|id| auth_state_store.get_mut(&id.0)) else { - return Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal { - state: ApiAuthState::NotStarted, - }))); - }; - - Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal { - state: state.verify().into(), - }))) - } - #[oai(path = "/auth/logout", method = "post", operation_id = "logout")] async fn api_auth_logout( &self, @@ -187,4 +183,129 @@ impl Api { info!("Logged out"); Ok(LogoutResponse::Success) } + + #[oai( + path = "/auth/state", + method = "get", + operation_id = "getDefaultAuthState" + )] + async fn api_default_auth_state( + &self, + session: &Session, + services: Data<&Services>, + ) -> poem::Result { + let Some(state_id) = session.get_auth_state_id() else { + return Ok(AuthStateResponse::NotFound) + }; + let store = services.auth_state_store.lock().await; + let Some(state_arc) = store.get(&state_id.0) else { + return Ok(AuthStateResponse::NotFound); + }; + serialize_auth_state_inner(state_arc).await + } + + #[oai( + path = "/auth/state/:id", + method = "get", + operation_id = "get_auth_state", + transform = "endpoint_auth" + )] + async fn api_auth_state( + &self, + services: Data<&Services>, + auth: Option>, + id: Path, + ) -> poem::Result { + let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else { + return Ok(AuthStateResponse::NotFound); + }; + serialize_auth_state_inner(state_arc).await + } + + #[oai( + path = "/auth/state/:id/approve", + method = "post", + operation_id = "approve_auth", + transform = "endpoint_auth" + )] + async fn api_approve_auth( + &self, + services: Data<&Services>, + auth: Option>, + id: Path, + ) -> poem::Result { + let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else { + return Ok(AuthStateResponse::NotFound); + }; + + let auth_result = { + let mut state = state_arc.lock().await; + state.add_valid_credential(AuthCredential::WebUserApproval); + state.verify() + }; + + if let AuthResult::Accepted { .. } = auth_result { + services.auth_state_store.lock().await.complete(&*id).await; + } + serialize_auth_state_inner(state_arc).await + } + + #[oai( + path = "/auth/state/:id/reject", + method = "post", + operation_id = "reject_auth", + transform = "endpoint_auth" + )] + async fn api_reject_auth( + &self, + services: Data<&Services>, + auth: Option>, + id: Path, + ) -> poem::Result { + let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else { + return Ok(AuthStateResponse::NotFound); + }; + state_arc.lock().await.reject(); + services.auth_state_store.lock().await.complete(&*id).await; + serialize_auth_state_inner(state_arc).await + } +} + +async fn get_auth_state( + id: &Uuid, + services: &Services, + auth: Option<&SessionAuthorization>, +) -> Option>> { + let store = services.auth_state_store.lock().await; + + let Some(auth) = auth else { + return None; + }; + + let SessionAuthorization::User(username) = auth else { + return None; + }; + + let Some(state_arc) = store.get(&*id) else { + return None; + }; + + { + let state = state_arc.lock().await; + if state.username() != username { + return None; + } + } + + Some(state_arc) +} + +async fn serialize_auth_state_inner( + state_arc: Arc>, +) -> poem::Result { + let state = state_arc.lock().await; + Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal { + protocol: state.protocol().to_string(), + state: state.verify().into(), + }))) } diff --git a/warpgate-protocol-http/src/api/sso_provider_detail.rs b/warpgate-protocol-http/src/api/sso_provider_detail.rs index 9b4f106..ee64d58 100644 --- a/warpgate-protocol-http/src/api/sso_provider_detail.rs +++ b/warpgate-protocol-http/src/api/sso_provider_detail.rs @@ -1,10 +1,9 @@ use poem::session::Session; use poem::web::Data; use poem::Request; -use poem_openapi::param::Path; +use poem_openapi::param::{Path, Query}; use poem_openapi::payload::Json; use poem_openapi::{ApiResponse, Object, OpenApi}; -use reqwest::Url; use serde::{Deserialize, Serialize}; use warpgate_common::Services; use warpgate_sso::{SsoClient, SsoLoginRequest}; @@ -31,6 +30,7 @@ pub static SSO_CONTEXT_SESSION_KEY: &str = "sso_request"; pub struct SsoContext { pub provider: String, pub request: SsoLoginRequest, + pub next_url: Option, } #[OpenApi] @@ -46,31 +46,14 @@ impl Api { session: &Session, services: Data<&Services>, name: Path, + next: Query>, ) -> poem::Result { let config = services.config.lock().await; let name = name.0; - let ext_host = config - .store - .external_host - .as_deref() - .or_else(|| req.original_uri().host()); - let Some(ext_host) = ext_host else { - return Err(poem::Error::from_string("external_host config option is required for SSO", http::status::StatusCode::INTERNAL_SERVER_ERROR)); - }; - let ext_port = config.store.http.listen.port(); - let mut return_url = Url::parse(&format!("https://{ext_host}/@warpgate/api/sso/return")) - .map_err(|e| { - poem::Error::from_string( - format!("failed to construct the return URL: {e}"), - http::status::StatusCode::INTERNAL_SERVER_ERROR, - ) - })?; - - if ext_port != 443 { - let _ = return_url.set_port(Some(ext_port)); - } + let mut return_url = config.construct_external_url(req.original_uri().host())?; + return_url.set_path("@warpgate/api/sso/return"); let Some(provider_config) = config.store.sso_providers.iter().find(|p| p.name == *name) else { return Ok(StartSsoResponse::NotFound); @@ -84,10 +67,14 @@ impl Api { .map_err(poem::error::InternalServerError)?; let url = sso_req.auth_url().to_string(); - session.set(SSO_CONTEXT_SESSION_KEY, SsoContext { - provider: name, - request: sso_req, - }); + session.set( + SSO_CONTEXT_SESSION_KEY, + SsoContext { + provider: name, + request: sso_req, + next_url: next.0.clone(), + }, + ); Ok(StartSsoResponse::Ok(Json(StartSsoResponseParams { url }))) } diff --git a/warpgate-protocol-http/src/api/sso_provider_list.rs b/warpgate-protocol-http/src/api/sso_provider_list.rs index 3a218af..27524cc 100644 --- a/warpgate-protocol-http/src/api/sso_provider_list.rs +++ b/warpgate-protocol-http/src/api/sso_provider_list.rs @@ -120,8 +120,9 @@ impl Api { }; let mut auth_state_store = services.auth_state_store.lock().await; - let state = get_auth_state_for_request(&username, session, &mut auth_state_store).await?; + let state_arc = get_auth_state_for_request(&username, session, &mut auth_state_store).await?; + let mut state = state_arc.lock().await; let mut cp = services.config_provider.lock().await; if cp.validate_credential(&username, &cred).await? { @@ -130,11 +131,15 @@ impl Api { match state.verify() { AuthResult::Accepted { username } => { + auth_state_store.complete(state.id()).await; authorize_session(req, username).await?; } - _ => () + _ => (), } - Ok(Response::new(ReturnToSsoResponse::Ok).header("Location", "/@warpgate")) + Ok(Response::new(ReturnToSsoResponse::Ok).header( + "Location", + context.next_url.as_deref().unwrap_or("/@warpgate#/login"), + )) } } diff --git a/warpgate-protocol-http/src/common.rs b/warpgate-protocol-http/src/common.rs index c6ef118..c767ab8 100644 --- a/warpgate-protocol-http/src/common.rs +++ b/warpgate-protocol-http/src/common.rs @@ -170,18 +170,18 @@ pub fn gateway_redirect(req: &Request) -> Response { .unwrap_or("".into()); let path = format!( - "/@warpgate?next={}", + "/@warpgate#/login?next={}", utf8_percent_encode(&path, NON_ALPHANUMERIC), ); Redirect::temporary(path).into_response() } -pub async fn get_auth_state_for_request<'a>( +pub async fn get_auth_state_for_request( username: &str, session: &Session, - store: &'a mut AuthStateStore, -) -> Result<&'a mut AuthState, WarpgateError> { + store: &mut AuthStateStore, +) -> Result>, WarpgateError> { match session.get_auth_state_id() { Some(id) => { if !store.contains_key(&id.0) { @@ -192,7 +192,7 @@ pub async fn get_auth_state_for_request<'a>( }; match session.get_auth_state_id() { - Some(id) => Ok(store.get_mut(&id.0).unwrap()), + Some(id) => Ok(store.get(&id.0).unwrap()), None => { let (id, state) = store .create(&username, crate::common::PROTOCOL_NAME) diff --git a/warpgate-protocol-mysql/src/session.rs b/warpgate-protocol-mysql/src/session.rs index 54a4090..537746b 100644 --- a/warpgate-protocol-mysql/src/session.rs +++ b/warpgate-protocol-mysql/src/session.rs @@ -7,7 +7,7 @@ use tokio::net::TcpStream; use tokio::sync::Mutex; use tracing::*; use uuid::Uuid; -use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState}; +use warpgate_common::auth::{AuthCredential, AuthSelector}; use warpgate_common::helpers::rng::get_crypto_rng; use warpgate_common::{ authorize_ticket, AuthResult, Secret, Services, TargetMySqlOptions, TargetOptions, @@ -180,15 +180,20 @@ impl MySqlSession { username, target_name, } => { - let user_auth_result = { - let mut cp = self.services.config_provider.lock().await; + let state_arc = self + .services + .auth_state_store + .lock() + .await + .create(&username, crate::common::PROTOCOL_NAME) + .await? + .1; + let mut state = state_arc.lock().await; + let user_auth_result = { let credential = AuthCredential::Password(password); - let mut state = AuthState::new( - username.clone(), - crate::common::PROTOCOL_NAME.to_string(), - cp.get_credential_policy(&username).await?, - ); + + let mut cp = self.services.config_provider.lock().await; if cp.validate_credential(&username, &credential).await? { state.add_valid_credential(credential); } @@ -198,6 +203,12 @@ impl MySqlSession { match user_auth_result { AuthResult::Accepted { username } => { + self.services + .auth_state_store + .lock() + .await + .complete(state.id()) + .await; let target_auth_result = { self.services .config_provider @@ -216,9 +227,7 @@ impl MySqlSession { } self.run_authorized(handshake, username, target_name).await } - AuthResult::Rejected - | AuthResult::Need(_) - | AuthResult::NeedMoreCredentials => fail(&mut self).await, // TODO SSO + AuthResult::Rejected | AuthResult::Need(_) => fail(&mut self).await, // TODO SSO } } AuthSelector::Ticket { secret } => { diff --git a/warpgate-protocol-ssh/Cargo.toml b/warpgate-protocol-ssh/Cargo.toml index 6dfb664..c12abb1 100644 --- a/warpgate-protocol-ssh/Cargo.toml +++ b/warpgate-protocol-ssh/Cargo.toml @@ -12,7 +12,7 @@ bimap = "0.6" bytes = "1.2" dialoguer = "0.10" futures = "0.3" -russh = {version = "0.34.0-beta.7", features = ["openssl"]} +russh = {version = "0.34.0-beta.8", features = ["openssl"]} russh-keys = {version = "0.22.0-beta.4", features = ["openssl"]} sea-orm = {version = "^0.9", features = ["runtime-tokio-native-tls"], default-features = false} thiserror = "1.0" diff --git a/warpgate-protocol-ssh/src/server/mod.rs b/warpgate-protocol-ssh/src/server/mod.rs index ac6699b..0c5a1b8 100644 --- a/warpgate-protocol-ssh/src/server/mod.rs +++ b/warpgate-protocol-ssh/src/server/mod.rs @@ -23,6 +23,7 @@ pub async fn run_server(services: Services, address: SocketAddr) -> Result<()> { let config = services.config.lock().await; russh::server::Config { auth_rejection_time: std::time::Duration::from_secs(1), + connection_timeout: Some(std::time::Duration::from_secs(300)), methods: MethodSet::PUBLICKEY | MethodSet::PASSWORD | MethodSet::KEYBOARD_INTERACTIVE, keys: load_host_keys(&config)?, ..Default::default() diff --git a/warpgate-protocol-ssh/src/server/session.rs b/warpgate-protocol-ssh/src/server/session.rs index 52ca34a..804c542 100644 --- a/warpgate-protocol-ssh/src/server/session.rs +++ b/warpgate-protocol-ssh/src/server/session.rs @@ -1,6 +1,6 @@ use std::borrow::Cow; use std::collections::hash_map::Entry::Vacant; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::net::{Ipv4Addr, SocketAddr}; use std::str::FromStr; use std::sync::Arc; @@ -10,11 +10,11 @@ use anyhow::{Context, Result}; use bimap::BiMap; use bytes::{Bytes, BytesMut}; use russh::server::Session; -use russh::{CryptoVec, Sig}; +use russh::{CryptoVec, MethodSet, Sig}; use russh_keys::key::PublicKey; use russh_keys::PublicKeyBase64; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; -use tokio::sync::{oneshot, Mutex}; +use tokio::sync::{broadcast, oneshot, Mutex}; use tracing::*; use uuid::Uuid; use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState, CredentialKind}; @@ -53,6 +53,12 @@ enum Event { Client(RCEvent), } +enum KeyboardInteractiveState { + None, + OtpRequested, + WebAuthRequested(broadcast::Receiver), +} + pub struct ServerSession { pub id: SessionId, username: Option, @@ -74,7 +80,8 @@ pub struct ServerSession { hub: EventHub, event_sender: EventSender, service_output: ServiceOutput, - auth_state: Option, + auth_state: Option>>, + keyboard_interactive_state: KeyboardInteractiveState, } fn session_debug_tag(id: &SessionId, remote_address: &SocketAddr) -> String { @@ -142,6 +149,7 @@ impl ServerSession { so_tx.send(BytesMut::from(data).freeze()).context("x") })), auth_state: None, + keyboard_interactive_state: KeyboardInteractiveState::None, }; let this = Arc::new(Mutex::new(this)); @@ -217,18 +225,23 @@ impl ServerSession { Ok(this) } - async fn get_auth_state(&mut self, username: &str) -> Result<&mut AuthState> { + async fn get_auth_state(&mut self, username: &str) -> Result>> { #[allow(clippy::unwrap_used)] - if self.auth_state.is_none() || self.auth_state.as_ref().unwrap().username() != username { - let mut cp = self.services.config_provider.lock().await; - self.auth_state = Some(AuthState::new( - username.to_string(), - crate::PROTOCOL_NAME.to_string(), - cp.get_credential_policy(username).await?, - )); + if self.auth_state.is_none() + || self.auth_state.as_ref().unwrap().lock().await.username() != username + { + let state = self + .services + .auth_state_store + .lock() + .await + .create(username, crate::PROTOCOL_NAME) + .await? + .1; + self.auth_state = Some(state); } #[allow(clippy::unwrap_used)] - Ok(self.auth_state.as_mut().unwrap()) + Ok(self.auth_state.as_ref().map(Clone::clone).unwrap()) } pub fn make_logging_span(&self) -> tracing::Span { @@ -939,13 +952,17 @@ impl ServerSession { .await { Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept, - Ok(AuthResult::Rejected) => russh::server::Auth::Reject, - Ok(AuthResult::Need(_) | AuthResult::NeedMoreCredentials) => { - russh::server::Auth::Reject - } + Ok(AuthResult::Rejected) => russh::server::Auth::Reject { + proceed_with_methods: Some(MethodSet::all()), + }, + Ok(AuthResult::Need(kinds)) => russh::server::Auth::Reject { + proceed_with_methods: Some(self.get_remaining_auth_methods(kinds)), + }, Err(error) => { error!(?error, "Failed to verify credentials"); - russh::server::Auth::Reject + russh::server::Auth::Reject { + proceed_with_methods: None, + } } } } @@ -963,13 +980,17 @@ impl ServerSession { .await { Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept, - Ok(AuthResult::Rejected) => russh::server::Auth::Reject, - Ok(AuthResult::Need(_) | AuthResult::NeedMoreCredentials) => { - russh::server::Auth::Reject - } + Ok(AuthResult::Rejected) => russh::server::Auth::Reject { + proceed_with_methods: None, + }, + Ok(AuthResult::Need(_)) => russh::server::Auth::Reject { + proceed_with_methods: None, + }, Err(error) => { error!(?error, "Failed to verify credentials"); - russh::server::Auth::Reject + russh::server::Auth::Reject { + proceed_with_methods: None, + } } } } @@ -982,27 +1003,111 @@ impl ServerSession { let selector: AuthSelector = ssh_username.expose_secret().into(); info!("Keyboard-interactive auth as {:?}", selector); - let cred = response.map(AuthCredential::Otp); + let cred; + match &mut self.keyboard_interactive_state { + KeyboardInteractiveState::None => { + cred = None; + } + KeyboardInteractiveState::OtpRequested => { + cred = response.map(AuthCredential::Otp); + } + KeyboardInteractiveState::WebAuthRequested(event) => { + cred = None; + let _ = event.recv().await; + // the auth state has been updated by now + } + } + + self.keyboard_interactive_state = KeyboardInteractiveState::None; match self.try_auth(&selector, cred).await { Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept, - Ok( - AuthResult::Rejected - | AuthResult::NeedMoreCredentials - | AuthResult::Need(CredentialKind::Otp), - ) => russh::server::Auth::Partial { - name: Cow::Borrowed("Two-factor authentication"), - instructions: Cow::Borrowed(""), - prompts: Cow::Owned(vec![(Cow::Borrowed("One-time password: "), true)]), + Ok(AuthResult::Rejected) => russh::server::Auth::Reject { + proceed_with_methods: None, }, - Ok(AuthResult::Need(_)) => russh::server::Auth::Reject, // TODO SSO + Ok(AuthResult::Need(kinds)) => { + if kinds.contains(&CredentialKind::Otp) { + self.keyboard_interactive_state = KeyboardInteractiveState::OtpRequested; + russh::server::Auth::Partial { + name: Cow::Borrowed("Two-factor authentication"), + instructions: Cow::Borrowed(""), + prompts: Cow::Owned(vec![(Cow::Borrowed("One-time password: "), true)]), + } + } else if kinds.contains(&CredentialKind::WebUserApproval) { + let Some(auth_state) = self.auth_state.as_ref() else { + return russh::server::Auth::Reject { proceed_with_methods: None}; + }; + let auth_state_id = *auth_state.lock().await.id(); + let event = self + .services + .auth_state_store + .lock() + .await + .subscribe(auth_state_id); + self.keyboard_interactive_state = + KeyboardInteractiveState::WebAuthRequested(event); + + let mut login_url = match self + .services + .config + .lock() + .await + .construct_external_url(None) + { + Ok(url) => url, + Err(error) => { + error!(?error, "Failed to construct external URL"); + return russh::server::Auth::Reject { + proceed_with_methods: None, + }; + } + }; + + login_url.set_path("@warpgate"); + login_url + .set_fragment(Some(&format!("/login?next=%2Flogin%2F{auth_state_id}"))); + + russh::server::Auth::Partial { + name: Cow::Owned(format!( + concat!( + "----------------------------------------------------------------\n", + "Warpgate authentication: please open {} in your browser\n", + "----------------------------------------------------------------\n" + ), + login_url + )), + instructions: Cow::Borrowed(""), + prompts: Cow::Owned(vec![(Cow::Borrowed("Press Enter when done: "), true)]), + } + } else { + russh::server::Auth::Reject { + proceed_with_methods: None, + } + } + } Err(error) => { error!(?error, "Failed to verify credentials"); - russh::server::Auth::Reject + russh::server::Auth::Reject { + proceed_with_methods: None, + } } } } + fn get_remaining_auth_methods(&self, kinds: HashSet) -> MethodSet { + let mut m = MethodSet::empty(); + for kind in kinds { + match kind { + CredentialKind::Password => m.insert(MethodSet::PASSWORD), + CredentialKind::Otp => m.insert(MethodSet::KEYBOARD_INTERACTIVE), + CredentialKind::WebUserApproval => m.insert(MethodSet::KEYBOARD_INTERACTIVE), + CredentialKind::PublicKey => m.insert(MethodSet::PUBLICKEY), + CredentialKind::Sso => m.insert(MethodSet::KEYBOARD_INTERACTIVE), + } + } + m + } + async fn try_auth( &mut self, selector: &AuthSelector, @@ -1014,7 +1119,9 @@ impl ServerSession { target_name, } => { let cp = self.services.config_provider.clone(); - let state = self.get_auth_state(username).await?; + + let state_arc = self.get_auth_state(username).await?; + let mut state = state_arc.lock().await; if let Some(credential) = credential { if cp @@ -1031,6 +1138,12 @@ impl ServerSession { match user_auth_result { AuthResult::Accepted { username } => { + self.services + .auth_state_store + .lock() + .await + .complete(state.id()) + .await; let target_auth_result = { self.services .config_provider diff --git a/warpgate-web/src/gateway/App.svelte b/warpgate-web/src/gateway/App.svelte index a59c834..7d4e016 100644 --- a/warpgate-web/src/gateway/App.svelte +++ b/warpgate-web/src/gateway/App.svelte @@ -2,22 +2,25 @@ import { faSignOut } from '@fortawesome/free-solid-svg-icons' import { Alert, Spinner } from 'sveltestrap' import Fa from 'svelte-fa' +import Router, { push } from 'svelte-spa-router' +import { wrap } from 'svelte-spa-router/wrap' +import { get } from 'svelte/store' import { api } from 'gateway/lib/api' import { reloadServerInfo, serverInfo } from 'gateway/lib/store' import ThemeSwitcher from 'common/ThemeSwitcher.svelte' -import Login from './Login.svelte' -import TargetList from './TargetList.svelte' import Logo from 'common/Logo.svelte' let redirecting = false +let serverInfoPromise = reloadServerInfo() async function init () { - await reloadServerInfo() + await serverInfoPromise } async function logout () { await api.logout() await reloadServerInfo() + push('/login') } function onPageResume () { @@ -25,6 +28,36 @@ function onPageResume () { init() } +async function requireLogin (detail) { + await serverInfoPromise + if (!get(serverInfo)?.username) { + let url = detail.location + if (detail.querystring) { + url += '?' + detail.querystring + } + push('/login?next=' + encodeURIComponent(url)) + return false + } + return true +} + +const routes = { + '/': wrap({ + asyncComponent: () => import('./TargetList.svelte'), + props: { + 'on:navigation': () => redirecting = true, + }, + conditions: [requireLogin], + }), + '/login': wrap({ + asyncComponent: () => import('./Login.svelte'), + }), + '/login/:stateId': wrap({ + asyncComponent: () => import('./OutOfBandAuth.svelte'), + conditions: [requireLogin], + }), +} + init() @@ -38,9 +71,9 @@ init() {:else}
- {#if $serverInfo?.username}
@@ -56,12 +89,7 @@ init()
- {#if $serverInfo?.username} - redirecting = true} /> - {:else} - - {/if} +