diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9f03cc1..3756fe0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -12,6 +12,7 @@ jobs: - uses: actions-rs/toolchain@v1 with: + toolchain: nightly target: x86_64-unknown-linux-gnu override: true @@ -41,6 +42,7 @@ jobs: uses: actions-rs/cargo@v1 with: command: build + toolchain: nightly use-cross: true args: --release --target x86_64-unknown-linux-gnu -Ztarget-applies-to-host diff --git a/Cargo.lock b/Cargo.lock index c944c07..98edba0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3143,9 +3143,9 @@ dependencies = [ [[package]] name = "russh" -version = "0.34.0-beta.8" +version = "0.34.0-beta.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccd8be93ee0b54a8a6b74c77ecef946185f0acbfb5234ea66666887621381e85" +checksum = "4f3bb72a66e32d52e0e258627d141d5c93b408e050f15033699caa836d064c7e" dependencies = [ "aes 0.8.1", "aes-gcm 0.10.1", @@ -4620,7 +4620,6 @@ dependencies = [ "bytes 1.2.1", "chrono", "data-encoding", - "futures", "humantime-serde", "lazy_static", "once_cell", @@ -4692,7 +4691,6 @@ version = "0.4.0" dependencies = [ "anyhow", "async-trait", - "chrono", "cookie", "data-encoding", "delegate", diff --git a/rust-toolchain b/rust-toolchain deleted file mode 100644 index 8350116..0000000 --- a/rust-toolchain +++ /dev/null @@ -1 +0,0 @@ -nightly-2022-08-01 diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..718c034 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly-2022-07-22" diff --git a/warpgate-common/Cargo.toml b/warpgate-common/Cargo.toml index e4f5ef6..bb17605 100644 --- a/warpgate-common/Cargo.toml +++ b/warpgate-common/Cargo.toml @@ -13,7 +13,6 @@ chrono = { version = "0.4", features = ["serde"] } data-encoding = "2.3" humantime-serde = "1.1" lazy_static = "1.4" -futures = "0.3" once_cell = "1.10" packet = "0.1" password-hash = "0.4" diff --git a/warpgate-common/src/auth/cred.rs b/warpgate-common/src/auth/cred.rs index 2548629..5e1badd 100644 --- a/warpgate-common/src/auth/cred.rs +++ b/warpgate-common/src/auth/cred.rs @@ -13,8 +13,6 @@ pub enum CredentialKind { Otp, #[serde(rename = "sso")] Sso, - #[serde(rename = "web")] - WebUserApproval, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -29,7 +27,6 @@ pub enum AuthCredential { provider: String, email: String, }, - WebUserApproval, } impl AuthCredential { @@ -39,7 +36,6 @@ impl AuthCredential { Self::PublicKey { .. } => CredentialKind::PublicKey, Self::Otp { .. } => CredentialKind::Otp, Self::Sso { .. } => CredentialKind::Sso, - Self::WebUserApproval => CredentialKind::WebUserApproval, } } } diff --git a/warpgate-common/src/auth/policy.rs b/warpgate-common/src/auth/policy.rs index 7d2df63..d349a59 100644 --- a/warpgate-common/src/auth/policy.rs +++ b/warpgate-common/src/auth/policy.rs @@ -1,10 +1,12 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use super::{AuthCredential, CredentialKind}; +use crate::UserRequireCredentialsPolicy; pub enum CredentialPolicyResponse { Ok, - Need(HashSet), + NeedMoreCredentials, + Need(CredentialKind), } pub trait CredentialPolicy { @@ -15,71 +17,36 @@ pub trait CredentialPolicy { ) -> CredentialPolicyResponse; } -pub struct AnySingleCredentialPolicy { - pub supported_credential_types: HashSet, -} - -pub struct AllCredentialsPolicy { - pub required_credential_types: HashSet, - pub supported_credential_types: HashSet, -} - -pub struct PerProtocolCredentialPolicy { - pub protocols: HashMap<&'static str, Box>, - pub default: Box, -} - -impl CredentialPolicy for AnySingleCredentialPolicy { - fn is_sufficient( - &self, - _protocol: &str, - valid_credentials: &[AuthCredential], - ) -> CredentialPolicyResponse { - if valid_credentials.is_empty() { - CredentialPolicyResponse::Need( - self.supported_credential_types - .clone() - .into_iter() - .collect(), - ) - } else { - CredentialPolicyResponse::Ok - } - } -} - -impl CredentialPolicy for AllCredentialsPolicy { - fn is_sufficient( - &self, - _protocol: &str, - valid_credentials: &[AuthCredential], - ) -> CredentialPolicyResponse { - let valid_credential_types: HashSet = - valid_credentials.iter().map(|x| x.kind()).collect(); - - if valid_credential_types.is_superset(&self.required_credential_types) { - CredentialPolicyResponse::Ok - } else { - CredentialPolicyResponse::Need( - self.required_credential_types - .difference(&valid_credential_types) - .cloned() - .collect(), - ) - } - } -} - -impl CredentialPolicy for PerProtocolCredentialPolicy { +impl CredentialPolicy for UserRequireCredentialsPolicy { fn is_sufficient( &self, protocol: &str, valid_credentials: &[AuthCredential], ) -> CredentialPolicyResponse { - if let Some(policy) = self.protocols.get(protocol) { - policy.is_sufficient(protocol, valid_credentials) + let required_kinds = match protocol { + "SSH" => &self.ssh, + "HTTP" => &self.http, + "MySQL" => &self.mysql, + _ => unreachable!(), + }; + if let Some(required_kinds) = required_kinds { + let mut remaining_required_kinds = HashSet::::new(); + remaining_required_kinds.extend(required_kinds); + for kind in required_kinds { + if valid_credentials.iter().any(|x| x.kind() == *kind) { + remaining_required_kinds.remove(kind); + } + } + + if let Some(kind) = remaining_required_kinds.into_iter().next() { + CredentialPolicyResponse::Need(kind) + } else { + CredentialPolicyResponse::Ok + } + } else if valid_credentials.is_empty() { + CredentialPolicyResponse::NeedMoreCredentials } else { - self.default.is_sufficient(protocol, valid_credentials) + CredentialPolicyResponse::Ok } } } diff --git a/warpgate-common/src/auth/state.rs b/warpgate-common/src/auth/state.rs index 41617e9..2ee9ea4 100644 --- a/warpgate-common/src/auth/state.rs +++ b/warpgate-common/src/auth/state.rs @@ -1,66 +1,71 @@ -use uuid::Uuid; +use std::time::{Duration, Instant}; + +use once_cell::sync::Lazy; +use tracing::warn; use super::{AuthCredential, CredentialPolicy, CredentialPolicyResponse}; use crate::AuthResult; +#[allow(clippy::unwrap_used)] +pub static TIMEOUT: Lazy = Lazy::new(|| Duration::from_secs(60 * 10)); + pub struct AuthState { - id: Uuid, username: String, protocol: String, - force_rejected: bool, - policy: Box, + policy: Option>, valid_credentials: Vec, + started_at: Instant, } impl AuthState { - pub(crate) fn new( - id: Uuid, + pub fn new( username: String, protocol: String, - policy: Box, + policy: Option>, ) -> Self { Self { - id, username, protocol, - force_rejected: false, policy, valid_credentials: vec![], + started_at: Instant::now(), } } - pub fn id(&self) -> &Uuid { - &self.id - } - pub fn username(&self) -> &str { &self.username } - pub fn protocol(&self) -> &str { - &self.protocol - } - pub fn add_valid_credential(&mut self, credential: AuthCredential) { self.valid_credentials.push(credential); } - pub fn reject(&mut self) { - self.force_rejected = true; + pub fn is_expired(&self) -> bool { + self.started_at.elapsed() > *TIMEOUT } pub fn verify(&self) -> AuthResult { - if self.force_rejected { + if self.valid_credentials.is_empty() { + warn!( + username=%self.username, + "No matching valid credentials" + ); return AuthResult::Rejected; } - match self - .policy - .is_sufficient(&self.protocol, &self.valid_credentials[..]) - { - CredentialPolicyResponse::Ok => AuthResult::Accepted { - username: self.username.clone(), - }, - CredentialPolicyResponse::Need(kinds) => AuthResult::Need(kinds), + + if let Some(ref policy) = self.policy { + match policy.is_sufficient(&self.protocol, &self.valid_credentials[..]) { + CredentialPolicyResponse::Ok => {} + CredentialPolicyResponse::Need(kind) => { + return AuthResult::Need(kind); + } + CredentialPolicyResponse::NeedMoreCredentials => { + return AuthResult::Rejected; + } + } + } + AuthResult::Accepted { + username: self.username.clone(), } } } diff --git a/warpgate-common/src/auth/store.rs b/warpgate-common/src/auth/store.rs index d3a8b8c..9d9a143 100644 --- a/warpgate-common/src/auth/store.rs +++ b/warpgate-common/src/auth/store.rs @@ -1,32 +1,15 @@ use std::collections::HashMap; use std::sync::Arc; -use std::time::{Duration, Instant}; -use once_cell::sync::Lazy; -use tokio::sync::{broadcast, Mutex}; +use tokio::sync::Mutex; use uuid::Uuid; use super::AuthState; -use crate::{AuthResult, ConfigProvider, WarpgateError}; - -#[allow(clippy::unwrap_used)] -pub static TIMEOUT: Lazy = Lazy::new(|| Duration::from_secs(60 * 10)); - -struct AuthCompletionSignal { - sender: broadcast::Sender, - created_at: Instant, -} - -impl AuthCompletionSignal { - pub fn is_expired(&self) -> bool { - self.created_at.elapsed() > *TIMEOUT - } -} +use crate::{ConfigProvider, WarpgateError}; pub struct AuthStateStore { config_provider: Arc>, - store: HashMap>, Instant)>, - completion_signals: HashMap, + store: HashMap, } impl AuthStateStore { @@ -34,66 +17,47 @@ impl AuthStateStore { Self { store: HashMap::new(), config_provider, - completion_signals: HashMap::new(), } } - pub fn contains_key(&self, id: &Uuid) -> bool { + pub fn contains_key(&mut self, id: &Uuid) -> bool { self.store.contains_key(id) } - pub fn get(&self, id: &Uuid) -> Option>> { - self.store.get(id).map(|x| x.0.clone()) + pub fn get_mut(&mut self, id: &Uuid) -> Option<&mut AuthState> { + self.store.get_mut(id) } pub async fn create( &mut self, username: &str, protocol: &str, - ) -> Result<(Uuid, Arc>), WarpgateError> { + ) -> Result<(Uuid, &mut AuthState), WarpgateError> { let id = Uuid::new_v4(); - let Some(policy) = self.config_provider - .lock() - .await - .get_credential_policy(username) - .await? else { - return Err(WarpgateError::UserNotFound) - }; - - let state = AuthState::new(id, username.to_string(), protocol.to_string(), policy); - self.store - .insert(id, (Arc::new(Mutex::new(state)), Instant::now())); + let state = AuthState::new( + username.to_string(), + protocol.to_string(), + self.config_provider + .lock() + .await + .get_credential_policy(username) + .await?, + ); + self.store.insert(id, state); #[allow(clippy::unwrap_used)] - Ok((id, self.get(&id).unwrap())) - } - - pub fn subscribe(&mut self, id: Uuid) -> broadcast::Receiver { - let signal = self.completion_signals.entry(id).or_insert_with(|| { - let (sender, _) = broadcast::channel(1); - AuthCompletionSignal { - sender, - created_at: Instant::now(), - } - }); - - signal.sender.subscribe() - } - - pub async fn complete(&mut self, id: &Uuid) { - let Some((state, _)) = self.store.get(id) else { - return - }; - if let Some(sig) = self.completion_signals.remove(id) { - let _ = sig.sender.send(state.lock().await.verify()); - } + Ok((id, self.store.get_mut(&id).unwrap())) } pub async fn vacuum(&mut self) { - self.store - .retain(|_, (_, started_at)| started_at.elapsed() < *TIMEOUT); - - self.completion_signals - .retain(|_, signal| !signal.is_expired()); + let mut to_remove = vec![]; + for (id, state) in self.store.iter() { + if state.is_expired() { + to_remove.push(*id); + } + } + for id in to_remove { + self.store.remove(&id); + } } } diff --git a/warpgate-common/src/config/mod.rs b/warpgate-common/src/config/mod.rs index 82e6081..bb3bd87 100644 --- a/warpgate-common/src/config/mod.rs +++ b/warpgate-common/src/config/mod.rs @@ -5,12 +5,11 @@ use std::time::Duration; use poem_openapi::{Enum, Object, Union}; use serde::{Deserialize, Serialize}; -use url::Url; use warpgate_sso::SsoProviderConfig; use crate::auth::CredentialKind; use crate::helpers::otp::OtpSecretKey; -use crate::{ListenEndpoint, Secret, WarpgateError}; +use crate::{ListenEndpoint, Secret}; const fn _default_true() -> bool { true @@ -427,24 +426,3 @@ pub struct WarpgateConfig { pub store: WarpgateConfigStore, pub paths_relative_to: PathBuf, } - -impl WarpgateConfig { - pub fn construct_external_url( - &self, - fallback_host: Option<&str>, - ) -> Result { - let ext_host = self.store.external_host.as_deref().or(fallback_host); - let Some(ext_host) = ext_host else { - return Err(WarpgateError::ExternalHostNotSet); - }; - let ext_port = self.store.http.listen.port(); - - let mut url = Url::parse(&format!("https://{ext_host}/"))?; - - if ext_port != 443 { - let _ = url.set_port(Some(ext_port)); - } - - Ok(url) - } -} diff --git a/warpgate-common/src/config_providers/file.rs b/warpgate-common/src/config_providers/file.rs index 6a87a86..b5eb0dc 100644 --- a/warpgate-common/src/config_providers/file.rs +++ b/warpgate-common/src/config_providers/file.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::sync::Arc; use async_trait::async_trait; @@ -11,10 +11,7 @@ use uuid::Uuid; use warpgate_db_entities::Ticket; use super::ConfigProvider; -use crate::auth::{ - AllCredentialsPolicy, AnySingleCredentialPolicy, AuthCredential, CredentialKind, - CredentialPolicy, PerProtocolCredentialPolicy, -}; +use crate::auth::{AuthCredential, CredentialPolicy}; use crate::helpers::hash::verify_password_hash; use crate::helpers::otp::verify_totp; use crate::{Target, User, UserAuthCredential, UserSnapshot, WarpgateConfig, WarpgateError}; @@ -81,52 +78,9 @@ impl ConfigProvider for FileConfigProvider { return Ok(None); }; - let supported_credential_types: HashSet = - user.credentials.iter().map(|x| x.kind()).collect(); - let default_policy = Box::new(AnySingleCredentialPolicy { - supported_credential_types: supported_credential_types.clone(), - }) as Box; - - if let Some(req) = user.require { - let mut policy = PerProtocolCredentialPolicy { - default: default_policy, - protocols: HashMap::new(), - }; - - if let Some(p) = req.http { - policy.protocols.insert( - "HTTP", - Box::new(AllCredentialsPolicy { - supported_credential_types: supported_credential_types.clone(), - required_credential_types: p.into_iter().collect(), - }), - ); - } - if let Some(p) = req.mysql { - policy.protocols.insert( - "MySQL", - Box::new(AllCredentialsPolicy { - supported_credential_types: supported_credential_types.clone(), - required_credential_types: p.into_iter().collect(), - }), - ); - } - if let Some(p) = req.ssh { - policy.protocols.insert( - "SSH", - Box::new(AllCredentialsPolicy { - supported_credential_types, - required_credential_types: p.into_iter().collect(), - }), - ); - } - - Ok(Some( - Box::new(policy) as Box - )) - } else { - Ok(Some(default_policy)) - } + Ok(user + .require + .map(|r| Box::new(r) as Box)) } async fn username_for_sso_credential( @@ -239,7 +193,6 @@ impl ConfigProvider for FileConfigProvider { } return Ok(false); } - _ => return Err(WarpgateError::InvalidCredentialType), } } diff --git a/warpgate-common/src/config_providers/mod.rs b/warpgate-common/src/config_providers/mod.rs index f23d1c5..3209d3f 100644 --- a/warpgate-common/src/config_providers/mod.rs +++ b/warpgate-common/src/config_providers/mod.rs @@ -1,5 +1,4 @@ mod file; -use std::collections::HashSet; use std::sync::Arc; use async_trait::async_trait; @@ -13,10 +12,11 @@ use warpgate_db_entities::Ticket; use crate::auth::{AuthCredential, CredentialKind, CredentialPolicy}; use crate::{Secret, Target, UserSnapshot, WarpgateError}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub enum AuthResult { Accepted { username: String }, - Need(HashSet), + Need(CredentialKind), + NeedMoreCredentials, Rejected, } diff --git a/warpgate-common/src/error.rs b/warpgate-common/src/error.rs index 1eeb26d..f37cc95 100644 --- a/warpgate-common/src/error.rs +++ b/warpgate-common/src/error.rs @@ -9,16 +9,8 @@ pub enum WarpgateError { DatabaseError(#[from] sea_orm::DbErr), #[error("ticket not found: {0}")] InvalidTicket(Uuid), - #[error("invalid credential type")] - InvalidCredentialType, #[error(transparent)] Other(Box), - #[error("user not found")] - UserNotFound, - #[error("failed to parse url: {0}")] - UrlParse(#[from] url::ParseError), - #[error("external_url config option is not set")] - ExternalHostNotSet, } impl ResponseError for WarpgateError { diff --git a/warpgate-protocol-http/Cargo.toml b/warpgate-protocol-http/Cargo.toml index c83819d..41293b1 100644 --- a/warpgate-protocol-http/Cargo.toml +++ b/warpgate-protocol-http/Cargo.toml @@ -7,7 +7,6 @@ version = "0.4.0" [dependencies] anyhow = "1.0" async-trait = "0.1" -chrono = {version = "0.4", features = ["serde"]} cookie = "0.16" data-encoding = "2.3" delegate = "0.6" diff --git a/warpgate-protocol-http/src/api/auth.rs b/warpgate-protocol-http/src/api/auth.rs index 04d06a0..9a9e019 100644 --- a/warpgate-protocol-http/src/api/auth.rs +++ b/warpgate-protocol-http/src/api/auth.rs @@ -3,18 +3,14 @@ use std::sync::Arc; use poem::session::Session; use poem::web::Data; use poem::Request; -use poem_openapi::param::Path; use poem_openapi::payload::Json; use poem_openapi::{ApiResponse, Enum, Object, OpenApi}; use tokio::sync::Mutex; use tracing::*; -use uuid::Uuid; -use warpgate_common::auth::{AuthCredential, AuthState, CredentialKind}; +use warpgate_common::auth::{AuthCredential, CredentialKind}; use warpgate_common::{AuthResult, Secret, Services}; -use crate::common::{ - authorize_session, endpoint_auth, get_auth_state_for_request, SessionAuthorization, SessionExt, -}; +use crate::common::{authorize_session, get_auth_state_for_request, SessionExt}; use crate::session::SessionStore; pub struct Api; @@ -37,8 +33,6 @@ enum ApiAuthState { PasswordNeeded, OtpNeeded, SsoNeeded, - WebUserApprovalNeeded, - PublicKeyNeeded, Success, } @@ -64,7 +58,6 @@ enum LogoutResponse { #[derive(Object)] struct AuthStateResponseInternal { - pub protocol: String, pub state: ApiAuthState, } @@ -72,22 +65,17 @@ struct AuthStateResponseInternal { enum AuthStateResponse { #[oai(status = 200)] Ok(Json), - #[oai(status = 404)] - NotFound, } impl From for ApiAuthState { fn from(state: AuthResult) -> Self { match state { AuthResult::Rejected => ApiAuthState::Failed, - AuthResult::Need(kinds) => match kinds.iter().next() { - Some(CredentialKind::Password) => ApiAuthState::PasswordNeeded, - Some(CredentialKind::Otp) => ApiAuthState::OtpNeeded, - Some(CredentialKind::Sso) => ApiAuthState::SsoNeeded, - Some(CredentialKind::WebUserApproval) => ApiAuthState::WebUserApprovalNeeded, - Some(CredentialKind::PublicKey) => ApiAuthState::PublicKeyNeeded, - None => ApiAuthState::Failed, - }, + AuthResult::Need(CredentialKind::Password) => ApiAuthState::PasswordNeeded, + AuthResult::Need(CredentialKind::Otp) => ApiAuthState::OtpNeeded, + AuthResult::Need(CredentialKind::Sso) => ApiAuthState::SsoNeeded, + AuthResult::Need(CredentialKind::PublicKey) => ApiAuthState::Failed, + AuthResult::NeedMoreCredentials => ApiAuthState::Failed, AuthResult::Accepted { .. } => ApiAuthState::Success, } } @@ -104,9 +92,8 @@ impl Api { body: Json, ) -> poem::Result { let mut auth_state_store = services.auth_state_store.lock().await; - let state_arc = + let state = get_auth_state_for_request(&body.username, session, &mut auth_state_store).await?; - let mut state = state_arc.lock().await; let mut cp = services.config_provider.lock().await; @@ -120,7 +107,6 @@ impl Api { match state.verify() { AuthResult::Accepted { username } => { - auth_state_store.complete(state.id()).await; authorize_session(req, username).await?; Ok(LoginResponse::Success) } @@ -145,14 +131,12 @@ impl Api { let mut auth_state_store = services.auth_state_store.lock().await; - let Some(state_arc) = state_id.and_then(|id| auth_state_store.get(&id.0)) else { + let Some(state) = state_id.and_then(|id| auth_state_store.get_mut(&id.0)) else { return Ok(LoginResponse::Failure(Json(LoginFailureResponse { state: ApiAuthState::NotStarted, }))) }; - let mut state = state_arc.lock().await; - let mut cp = services.config_provider.lock().await; let otp_cred = AuthCredential::Otp(body.otp.clone().into()); @@ -162,7 +146,6 @@ impl Api { match state.verify() { AuthResult::Accepted { username } => { - auth_state_store.complete(state.id()).await; authorize_session(req, username).await?; Ok(LoginResponse::Success) } @@ -172,6 +155,27 @@ impl Api { } } + #[oai(path = "/auth/state", method = "get", operation_id = "getAuthState")] + async fn api_auth_state( + &self, + session: &Session, + services: Data<&Services>, + ) -> poem::Result { + let state_id = session.get_auth_state_id(); + + let mut auth_state_store = services.auth_state_store.lock().await; + + let Some(state) = state_id.and_then(|id| auth_state_store.get_mut(&id.0)) else { + return Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal { + state: ApiAuthState::NotStarted, + }))); + }; + + Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal { + state: state.verify().into(), + }))) + } + #[oai(path = "/auth/logout", method = "post", operation_id = "logout")] async fn api_auth_logout( &self, @@ -183,129 +187,4 @@ impl Api { info!("Logged out"); Ok(LogoutResponse::Success) } - - #[oai( - path = "/auth/state", - method = "get", - operation_id = "getDefaultAuthState" - )] - async fn api_default_auth_state( - &self, - session: &Session, - services: Data<&Services>, - ) -> poem::Result { - let Some(state_id) = session.get_auth_state_id() else { - return Ok(AuthStateResponse::NotFound) - }; - let store = services.auth_state_store.lock().await; - let Some(state_arc) = store.get(&state_id.0) else { - return Ok(AuthStateResponse::NotFound); - }; - serialize_auth_state_inner(state_arc).await - } - - #[oai( - path = "/auth/state/:id", - method = "get", - operation_id = "get_auth_state", - transform = "endpoint_auth" - )] - async fn api_auth_state( - &self, - services: Data<&Services>, - auth: Option>, - id: Path, - ) -> poem::Result { - let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else { - return Ok(AuthStateResponse::NotFound); - }; - serialize_auth_state_inner(state_arc).await - } - - #[oai( - path = "/auth/state/:id/approve", - method = "post", - operation_id = "approve_auth", - transform = "endpoint_auth" - )] - async fn api_approve_auth( - &self, - services: Data<&Services>, - auth: Option>, - id: Path, - ) -> poem::Result { - let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else { - return Ok(AuthStateResponse::NotFound); - }; - - let auth_result = { - let mut state = state_arc.lock().await; - state.add_valid_credential(AuthCredential::WebUserApproval); - state.verify() - }; - - if let AuthResult::Accepted { .. } = auth_result { - services.auth_state_store.lock().await.complete(&*id).await; - } - serialize_auth_state_inner(state_arc).await - } - - #[oai( - path = "/auth/state/:id/reject", - method = "post", - operation_id = "reject_auth", - transform = "endpoint_auth" - )] - async fn api_reject_auth( - &self, - services: Data<&Services>, - auth: Option>, - id: Path, - ) -> poem::Result { - let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else { - return Ok(AuthStateResponse::NotFound); - }; - state_arc.lock().await.reject(); - services.auth_state_store.lock().await.complete(&*id).await; - serialize_auth_state_inner(state_arc).await - } -} - -async fn get_auth_state( - id: &Uuid, - services: &Services, - auth: Option<&SessionAuthorization>, -) -> Option>> { - let store = services.auth_state_store.lock().await; - - let Some(auth) = auth else { - return None; - }; - - let SessionAuthorization::User(username) = auth else { - return None; - }; - - let Some(state_arc) = store.get(&*id) else { - return None; - }; - - { - let state = state_arc.lock().await; - if state.username() != username { - return None; - } - } - - Some(state_arc) -} - -async fn serialize_auth_state_inner( - state_arc: Arc>, -) -> poem::Result { - let state = state_arc.lock().await; - Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal { - protocol: state.protocol().to_string(), - state: state.verify().into(), - }))) } diff --git a/warpgate-protocol-http/src/api/sso_provider_detail.rs b/warpgate-protocol-http/src/api/sso_provider_detail.rs index ee64d58..9b4f106 100644 --- a/warpgate-protocol-http/src/api/sso_provider_detail.rs +++ b/warpgate-protocol-http/src/api/sso_provider_detail.rs @@ -1,9 +1,10 @@ use poem::session::Session; use poem::web::Data; use poem::Request; -use poem_openapi::param::{Path, Query}; +use poem_openapi::param::Path; use poem_openapi::payload::Json; use poem_openapi::{ApiResponse, Object, OpenApi}; +use reqwest::Url; use serde::{Deserialize, Serialize}; use warpgate_common::Services; use warpgate_sso::{SsoClient, SsoLoginRequest}; @@ -30,7 +31,6 @@ pub static SSO_CONTEXT_SESSION_KEY: &str = "sso_request"; pub struct SsoContext { pub provider: String, pub request: SsoLoginRequest, - pub next_url: Option, } #[OpenApi] @@ -46,14 +46,31 @@ impl Api { session: &Session, services: Data<&Services>, name: Path, - next: Query>, ) -> poem::Result { let config = services.config.lock().await; let name = name.0; + let ext_host = config + .store + .external_host + .as_deref() + .or_else(|| req.original_uri().host()); + let Some(ext_host) = ext_host else { + return Err(poem::Error::from_string("external_host config option is required for SSO", http::status::StatusCode::INTERNAL_SERVER_ERROR)); + }; + let ext_port = config.store.http.listen.port(); - let mut return_url = config.construct_external_url(req.original_uri().host())?; - return_url.set_path("@warpgate/api/sso/return"); + let mut return_url = Url::parse(&format!("https://{ext_host}/@warpgate/api/sso/return")) + .map_err(|e| { + poem::Error::from_string( + format!("failed to construct the return URL: {e}"), + http::status::StatusCode::INTERNAL_SERVER_ERROR, + ) + })?; + + if ext_port != 443 { + let _ = return_url.set_port(Some(ext_port)); + } let Some(provider_config) = config.store.sso_providers.iter().find(|p| p.name == *name) else { return Ok(StartSsoResponse::NotFound); @@ -67,14 +84,10 @@ impl Api { .map_err(poem::error::InternalServerError)?; let url = sso_req.auth_url().to_string(); - session.set( - SSO_CONTEXT_SESSION_KEY, - SsoContext { - provider: name, - request: sso_req, - next_url: next.0.clone(), - }, - ); + session.set(SSO_CONTEXT_SESSION_KEY, SsoContext { + provider: name, + request: sso_req, + }); Ok(StartSsoResponse::Ok(Json(StartSsoResponseParams { url }))) } diff --git a/warpgate-protocol-http/src/api/sso_provider_list.rs b/warpgate-protocol-http/src/api/sso_provider_list.rs index 27524cc..3a218af 100644 --- a/warpgate-protocol-http/src/api/sso_provider_list.rs +++ b/warpgate-protocol-http/src/api/sso_provider_list.rs @@ -120,9 +120,8 @@ impl Api { }; let mut auth_state_store = services.auth_state_store.lock().await; - let state_arc = get_auth_state_for_request(&username, session, &mut auth_state_store).await?; + let state = get_auth_state_for_request(&username, session, &mut auth_state_store).await?; - let mut state = state_arc.lock().await; let mut cp = services.config_provider.lock().await; if cp.validate_credential(&username, &cred).await? { @@ -131,15 +130,11 @@ impl Api { match state.verify() { AuthResult::Accepted { username } => { - auth_state_store.complete(state.id()).await; authorize_session(req, username).await?; } - _ => (), + _ => () } - Ok(Response::new(ReturnToSsoResponse::Ok).header( - "Location", - context.next_url.as_deref().unwrap_or("/@warpgate#/login"), - )) + Ok(Response::new(ReturnToSsoResponse::Ok).header("Location", "/@warpgate")) } } diff --git a/warpgate-protocol-http/src/common.rs b/warpgate-protocol-http/src/common.rs index c767ab8..c6ef118 100644 --- a/warpgate-protocol-http/src/common.rs +++ b/warpgate-protocol-http/src/common.rs @@ -170,18 +170,18 @@ pub fn gateway_redirect(req: &Request) -> Response { .unwrap_or("".into()); let path = format!( - "/@warpgate#/login?next={}", + "/@warpgate?next={}", utf8_percent_encode(&path, NON_ALPHANUMERIC), ); Redirect::temporary(path).into_response() } -pub async fn get_auth_state_for_request( +pub async fn get_auth_state_for_request<'a>( username: &str, session: &Session, - store: &mut AuthStateStore, -) -> Result>, WarpgateError> { + store: &'a mut AuthStateStore, +) -> Result<&'a mut AuthState, WarpgateError> { match session.get_auth_state_id() { Some(id) => { if !store.contains_key(&id.0) { @@ -192,7 +192,7 @@ pub async fn get_auth_state_for_request( }; match session.get_auth_state_id() { - Some(id) => Ok(store.get(&id.0).unwrap()), + Some(id) => Ok(store.get_mut(&id.0).unwrap()), None => { let (id, state) = store .create(&username, crate::common::PROTOCOL_NAME) diff --git a/warpgate-protocol-mysql/src/session.rs b/warpgate-protocol-mysql/src/session.rs index 537746b..54a4090 100644 --- a/warpgate-protocol-mysql/src/session.rs +++ b/warpgate-protocol-mysql/src/session.rs @@ -7,7 +7,7 @@ use tokio::net::TcpStream; use tokio::sync::Mutex; use tracing::*; use uuid::Uuid; -use warpgate_common::auth::{AuthCredential, AuthSelector}; +use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState}; use warpgate_common::helpers::rng::get_crypto_rng; use warpgate_common::{ authorize_ticket, AuthResult, Secret, Services, TargetMySqlOptions, TargetOptions, @@ -180,20 +180,15 @@ impl MySqlSession { username, target_name, } => { - let state_arc = self - .services - .auth_state_store - .lock() - .await - .create(&username, crate::common::PROTOCOL_NAME) - .await? - .1; - let mut state = state_arc.lock().await; - let user_auth_result = { - let credential = AuthCredential::Password(password); - let mut cp = self.services.config_provider.lock().await; + + let credential = AuthCredential::Password(password); + let mut state = AuthState::new( + username.clone(), + crate::common::PROTOCOL_NAME.to_string(), + cp.get_credential_policy(&username).await?, + ); if cp.validate_credential(&username, &credential).await? { state.add_valid_credential(credential); } @@ -203,12 +198,6 @@ impl MySqlSession { match user_auth_result { AuthResult::Accepted { username } => { - self.services - .auth_state_store - .lock() - .await - .complete(state.id()) - .await; let target_auth_result = { self.services .config_provider @@ -227,7 +216,9 @@ impl MySqlSession { } self.run_authorized(handshake, username, target_name).await } - AuthResult::Rejected | AuthResult::Need(_) => fail(&mut self).await, // TODO SSO + AuthResult::Rejected + | AuthResult::Need(_) + | AuthResult::NeedMoreCredentials => fail(&mut self).await, // TODO SSO } } AuthSelector::Ticket { secret } => { diff --git a/warpgate-protocol-ssh/Cargo.toml b/warpgate-protocol-ssh/Cargo.toml index c12abb1..6dfb664 100644 --- a/warpgate-protocol-ssh/Cargo.toml +++ b/warpgate-protocol-ssh/Cargo.toml @@ -12,7 +12,7 @@ bimap = "0.6" bytes = "1.2" dialoguer = "0.10" futures = "0.3" -russh = {version = "0.34.0-beta.8", features = ["openssl"]} +russh = {version = "0.34.0-beta.7", features = ["openssl"]} russh-keys = {version = "0.22.0-beta.4", features = ["openssl"]} sea-orm = {version = "^0.9", features = ["runtime-tokio-native-tls"], default-features = false} thiserror = "1.0" diff --git a/warpgate-protocol-ssh/src/server/mod.rs b/warpgate-protocol-ssh/src/server/mod.rs index 0c5a1b8..ac6699b 100644 --- a/warpgate-protocol-ssh/src/server/mod.rs +++ b/warpgate-protocol-ssh/src/server/mod.rs @@ -23,7 +23,6 @@ pub async fn run_server(services: Services, address: SocketAddr) -> Result<()> { let config = services.config.lock().await; russh::server::Config { auth_rejection_time: std::time::Duration::from_secs(1), - connection_timeout: Some(std::time::Duration::from_secs(300)), methods: MethodSet::PUBLICKEY | MethodSet::PASSWORD | MethodSet::KEYBOARD_INTERACTIVE, keys: load_host_keys(&config)?, ..Default::default() diff --git a/warpgate-protocol-ssh/src/server/session.rs b/warpgate-protocol-ssh/src/server/session.rs index 804c542..52ca34a 100644 --- a/warpgate-protocol-ssh/src/server/session.rs +++ b/warpgate-protocol-ssh/src/server/session.rs @@ -1,6 +1,6 @@ use std::borrow::Cow; use std::collections::hash_map::Entry::Vacant; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::net::{Ipv4Addr, SocketAddr}; use std::str::FromStr; use std::sync::Arc; @@ -10,11 +10,11 @@ use anyhow::{Context, Result}; use bimap::BiMap; use bytes::{Bytes, BytesMut}; use russh::server::Session; -use russh::{CryptoVec, MethodSet, Sig}; +use russh::{CryptoVec, Sig}; use russh_keys::key::PublicKey; use russh_keys::PublicKeyBase64; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; -use tokio::sync::{broadcast, oneshot, Mutex}; +use tokio::sync::{oneshot, Mutex}; use tracing::*; use uuid::Uuid; use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState, CredentialKind}; @@ -53,12 +53,6 @@ enum Event { Client(RCEvent), } -enum KeyboardInteractiveState { - None, - OtpRequested, - WebAuthRequested(broadcast::Receiver), -} - pub struct ServerSession { pub id: SessionId, username: Option, @@ -80,8 +74,7 @@ pub struct ServerSession { hub: EventHub, event_sender: EventSender, service_output: ServiceOutput, - auth_state: Option>>, - keyboard_interactive_state: KeyboardInteractiveState, + auth_state: Option, } fn session_debug_tag(id: &SessionId, remote_address: &SocketAddr) -> String { @@ -149,7 +142,6 @@ impl ServerSession { so_tx.send(BytesMut::from(data).freeze()).context("x") })), auth_state: None, - keyboard_interactive_state: KeyboardInteractiveState::None, }; let this = Arc::new(Mutex::new(this)); @@ -225,23 +217,18 @@ impl ServerSession { Ok(this) } - async fn get_auth_state(&mut self, username: &str) -> Result>> { + async fn get_auth_state(&mut self, username: &str) -> Result<&mut AuthState> { #[allow(clippy::unwrap_used)] - if self.auth_state.is_none() - || self.auth_state.as_ref().unwrap().lock().await.username() != username - { - let state = self - .services - .auth_state_store - .lock() - .await - .create(username, crate::PROTOCOL_NAME) - .await? - .1; - self.auth_state = Some(state); + if self.auth_state.is_none() || self.auth_state.as_ref().unwrap().username() != username { + let mut cp = self.services.config_provider.lock().await; + self.auth_state = Some(AuthState::new( + username.to_string(), + crate::PROTOCOL_NAME.to_string(), + cp.get_credential_policy(username).await?, + )); } #[allow(clippy::unwrap_used)] - Ok(self.auth_state.as_ref().map(Clone::clone).unwrap()) + Ok(self.auth_state.as_mut().unwrap()) } pub fn make_logging_span(&self) -> tracing::Span { @@ -952,17 +939,13 @@ impl ServerSession { .await { Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept, - Ok(AuthResult::Rejected) => russh::server::Auth::Reject { - proceed_with_methods: Some(MethodSet::all()), - }, - Ok(AuthResult::Need(kinds)) => russh::server::Auth::Reject { - proceed_with_methods: Some(self.get_remaining_auth_methods(kinds)), - }, + Ok(AuthResult::Rejected) => russh::server::Auth::Reject, + Ok(AuthResult::Need(_) | AuthResult::NeedMoreCredentials) => { + russh::server::Auth::Reject + } Err(error) => { error!(?error, "Failed to verify credentials"); - russh::server::Auth::Reject { - proceed_with_methods: None, - } + russh::server::Auth::Reject } } } @@ -980,17 +963,13 @@ impl ServerSession { .await { Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept, - Ok(AuthResult::Rejected) => russh::server::Auth::Reject { - proceed_with_methods: None, - }, - Ok(AuthResult::Need(_)) => russh::server::Auth::Reject { - proceed_with_methods: None, - }, + Ok(AuthResult::Rejected) => russh::server::Auth::Reject, + Ok(AuthResult::Need(_) | AuthResult::NeedMoreCredentials) => { + russh::server::Auth::Reject + } Err(error) => { error!(?error, "Failed to verify credentials"); - russh::server::Auth::Reject { - proceed_with_methods: None, - } + russh::server::Auth::Reject } } } @@ -1003,111 +982,27 @@ impl ServerSession { let selector: AuthSelector = ssh_username.expose_secret().into(); info!("Keyboard-interactive auth as {:?}", selector); - let cred; - match &mut self.keyboard_interactive_state { - KeyboardInteractiveState::None => { - cred = None; - } - KeyboardInteractiveState::OtpRequested => { - cred = response.map(AuthCredential::Otp); - } - KeyboardInteractiveState::WebAuthRequested(event) => { - cred = None; - let _ = event.recv().await; - // the auth state has been updated by now - } - } - - self.keyboard_interactive_state = KeyboardInteractiveState::None; + let cred = response.map(AuthCredential::Otp); match self.try_auth(&selector, cred).await { Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept, - Ok(AuthResult::Rejected) => russh::server::Auth::Reject { - proceed_with_methods: None, + Ok( + AuthResult::Rejected + | AuthResult::NeedMoreCredentials + | AuthResult::Need(CredentialKind::Otp), + ) => russh::server::Auth::Partial { + name: Cow::Borrowed("Two-factor authentication"), + instructions: Cow::Borrowed(""), + prompts: Cow::Owned(vec![(Cow::Borrowed("One-time password: "), true)]), }, - Ok(AuthResult::Need(kinds)) => { - if kinds.contains(&CredentialKind::Otp) { - self.keyboard_interactive_state = KeyboardInteractiveState::OtpRequested; - russh::server::Auth::Partial { - name: Cow::Borrowed("Two-factor authentication"), - instructions: Cow::Borrowed(""), - prompts: Cow::Owned(vec![(Cow::Borrowed("One-time password: "), true)]), - } - } else if kinds.contains(&CredentialKind::WebUserApproval) { - let Some(auth_state) = self.auth_state.as_ref() else { - return russh::server::Auth::Reject { proceed_with_methods: None}; - }; - let auth_state_id = *auth_state.lock().await.id(); - let event = self - .services - .auth_state_store - .lock() - .await - .subscribe(auth_state_id); - self.keyboard_interactive_state = - KeyboardInteractiveState::WebAuthRequested(event); - - let mut login_url = match self - .services - .config - .lock() - .await - .construct_external_url(None) - { - Ok(url) => url, - Err(error) => { - error!(?error, "Failed to construct external URL"); - return russh::server::Auth::Reject { - proceed_with_methods: None, - }; - } - }; - - login_url.set_path("@warpgate"); - login_url - .set_fragment(Some(&format!("/login?next=%2Flogin%2F{auth_state_id}"))); - - russh::server::Auth::Partial { - name: Cow::Owned(format!( - concat!( - "----------------------------------------------------------------\n", - "Warpgate authentication: please open {} in your browser\n", - "----------------------------------------------------------------\n" - ), - login_url - )), - instructions: Cow::Borrowed(""), - prompts: Cow::Owned(vec![(Cow::Borrowed("Press Enter when done: "), true)]), - } - } else { - russh::server::Auth::Reject { - proceed_with_methods: None, - } - } - } + Ok(AuthResult::Need(_)) => russh::server::Auth::Reject, // TODO SSO Err(error) => { error!(?error, "Failed to verify credentials"); - russh::server::Auth::Reject { - proceed_with_methods: None, - } + russh::server::Auth::Reject } } } - fn get_remaining_auth_methods(&self, kinds: HashSet) -> MethodSet { - let mut m = MethodSet::empty(); - for kind in kinds { - match kind { - CredentialKind::Password => m.insert(MethodSet::PASSWORD), - CredentialKind::Otp => m.insert(MethodSet::KEYBOARD_INTERACTIVE), - CredentialKind::WebUserApproval => m.insert(MethodSet::KEYBOARD_INTERACTIVE), - CredentialKind::PublicKey => m.insert(MethodSet::PUBLICKEY), - CredentialKind::Sso => m.insert(MethodSet::KEYBOARD_INTERACTIVE), - } - } - m - } - async fn try_auth( &mut self, selector: &AuthSelector, @@ -1119,9 +1014,7 @@ impl ServerSession { target_name, } => { let cp = self.services.config_provider.clone(); - - let state_arc = self.get_auth_state(username).await?; - let mut state = state_arc.lock().await; + let state = self.get_auth_state(username).await?; if let Some(credential) = credential { if cp @@ -1138,12 +1031,6 @@ impl ServerSession { match user_auth_result { AuthResult::Accepted { username } => { - self.services - .auth_state_store - .lock() - .await - .complete(state.id()) - .await; let target_auth_result = { self.services .config_provider diff --git a/warpgate-web/src/gateway/App.svelte b/warpgate-web/src/gateway/App.svelte index 7d4e016..a59c834 100644 --- a/warpgate-web/src/gateway/App.svelte +++ b/warpgate-web/src/gateway/App.svelte @@ -2,25 +2,22 @@ import { faSignOut } from '@fortawesome/free-solid-svg-icons' import { Alert, Spinner } from 'sveltestrap' import Fa from 'svelte-fa' -import Router, { push } from 'svelte-spa-router' -import { wrap } from 'svelte-spa-router/wrap' -import { get } from 'svelte/store' import { api } from 'gateway/lib/api' import { reloadServerInfo, serverInfo } from 'gateway/lib/store' import ThemeSwitcher from 'common/ThemeSwitcher.svelte' +import Login from './Login.svelte' +import TargetList from './TargetList.svelte' import Logo from 'common/Logo.svelte' let redirecting = false -let serverInfoPromise = reloadServerInfo() async function init () { - await serverInfoPromise + await reloadServerInfo() } async function logout () { await api.logout() await reloadServerInfo() - push('/login') } function onPageResume () { @@ -28,36 +25,6 @@ function onPageResume () { init() } -async function requireLogin (detail) { - await serverInfoPromise - if (!get(serverInfo)?.username) { - let url = detail.location - if (detail.querystring) { - url += '?' + detail.querystring - } - push('/login?next=' + encodeURIComponent(url)) - return false - } - return true -} - -const routes = { - '/': wrap({ - asyncComponent: () => import('./TargetList.svelte'), - props: { - 'on:navigation': () => redirecting = true, - }, - conditions: [requireLogin], - }), - '/login': wrap({ - asyncComponent: () => import('./Login.svelte'), - }), - '/login/:stateId': wrap({ - asyncComponent: () => import('./OutOfBandAuth.svelte'), - conditions: [requireLogin], - }), -} - init() @@ -71,9 +38,9 @@ init() {:else}
- {#if $serverInfo?.username}
@@ -89,7 +56,12 @@ init()
- + {#if $serverInfo?.username} + redirecting = true} /> + {:else} + + {/if}