Out-of-band SSO (#245)

This commit is contained in:
Eugene 2022-08-08 23:30:18 +02:00 committed by GitHub
parent fbd8d0dda3
commit c6885f18c3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 863 additions and 248 deletions

View file

@ -12,7 +12,6 @@ jobs:
- uses: actions-rs/toolchain@v1 - uses: actions-rs/toolchain@v1
with: with:
toolchain: nightly
target: x86_64-unknown-linux-gnu target: x86_64-unknown-linux-gnu
override: true override: true
@ -42,7 +41,6 @@ jobs:
uses: actions-rs/cargo@v1 uses: actions-rs/cargo@v1
with: with:
command: build command: build
toolchain: nightly
use-cross: true use-cross: true
args: --release --target x86_64-unknown-linux-gnu -Ztarget-applies-to-host args: --release --target x86_64-unknown-linux-gnu -Ztarget-applies-to-host

6
Cargo.lock generated
View file

@ -3143,9 +3143,9 @@ dependencies = [
[[package]] [[package]]
name = "russh" name = "russh"
version = "0.34.0-beta.7" version = "0.34.0-beta.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f3bb72a66e32d52e0e258627d141d5c93b408e050f15033699caa836d064c7e" checksum = "ccd8be93ee0b54a8a6b74c77ecef946185f0acbfb5234ea66666887621381e85"
dependencies = [ dependencies = [
"aes 0.8.1", "aes 0.8.1",
"aes-gcm 0.10.1", "aes-gcm 0.10.1",
@ -4620,6 +4620,7 @@ dependencies = [
"bytes 1.2.1", "bytes 1.2.1",
"chrono", "chrono",
"data-encoding", "data-encoding",
"futures",
"humantime-serde", "humantime-serde",
"lazy_static", "lazy_static",
"once_cell", "once_cell",
@ -4691,6 +4692,7 @@ version = "0.4.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
"chrono",
"cookie", "cookie",
"data-encoding", "data-encoding",
"delegate", "delegate",

1
rust-toolchain Normal file
View file

@ -0,0 +1 @@
nightly-2022-08-01

View file

@ -1,2 +0,0 @@
[toolchain]
channel = "nightly-2022-07-22"

View file

@ -13,6 +13,7 @@ chrono = { version = "0.4", features = ["serde"] }
data-encoding = "2.3" data-encoding = "2.3"
humantime-serde = "1.1" humantime-serde = "1.1"
lazy_static = "1.4" lazy_static = "1.4"
futures = "0.3"
once_cell = "1.10" once_cell = "1.10"
packet = "0.1" packet = "0.1"
password-hash = "0.4" password-hash = "0.4"

View file

@ -13,6 +13,8 @@ pub enum CredentialKind {
Otp, Otp,
#[serde(rename = "sso")] #[serde(rename = "sso")]
Sso, Sso,
#[serde(rename = "web")]
WebUserApproval,
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
@ -27,6 +29,7 @@ pub enum AuthCredential {
provider: String, provider: String,
email: String, email: String,
}, },
WebUserApproval,
} }
impl AuthCredential { impl AuthCredential {
@ -36,6 +39,7 @@ impl AuthCredential {
Self::PublicKey { .. } => CredentialKind::PublicKey, Self::PublicKey { .. } => CredentialKind::PublicKey,
Self::Otp { .. } => CredentialKind::Otp, Self::Otp { .. } => CredentialKind::Otp,
Self::Sso { .. } => CredentialKind::Sso, Self::Sso { .. } => CredentialKind::Sso,
Self::WebUserApproval => CredentialKind::WebUserApproval,
} }
} }
} }

View file

@ -1,12 +1,10 @@
use std::collections::HashSet; use std::collections::{HashMap, HashSet};
use super::{AuthCredential, CredentialKind}; use super::{AuthCredential, CredentialKind};
use crate::UserRequireCredentialsPolicy;
pub enum CredentialPolicyResponse { pub enum CredentialPolicyResponse {
Ok, Ok,
NeedMoreCredentials, Need(HashSet<CredentialKind>),
Need(CredentialKind),
} }
pub trait CredentialPolicy { pub trait CredentialPolicy {
@ -17,36 +15,71 @@ pub trait CredentialPolicy {
) -> CredentialPolicyResponse; ) -> CredentialPolicyResponse;
} }
impl CredentialPolicy for UserRequireCredentialsPolicy { pub struct AnySingleCredentialPolicy {
pub supported_credential_types: HashSet<CredentialKind>,
}
pub struct AllCredentialsPolicy {
pub required_credential_types: HashSet<CredentialKind>,
pub supported_credential_types: HashSet<CredentialKind>,
}
pub struct PerProtocolCredentialPolicy {
pub protocols: HashMap<&'static str, Box<dyn CredentialPolicy + Send + Sync>>,
pub default: Box<dyn CredentialPolicy + Send + Sync>,
}
impl CredentialPolicy for AnySingleCredentialPolicy {
fn is_sufficient( fn is_sufficient(
&self, &self,
protocol: &str, _protocol: &str,
valid_credentials: &[AuthCredential], valid_credentials: &[AuthCredential],
) -> CredentialPolicyResponse { ) -> CredentialPolicyResponse {
let required_kinds = match protocol { if valid_credentials.is_empty() {
"SSH" => &self.ssh, CredentialPolicyResponse::Need(
"HTTP" => &self.http, self.supported_credential_types
"MySQL" => &self.mysql, .clone()
_ => unreachable!(), .into_iter()
}; .collect(),
if let Some(required_kinds) = required_kinds { )
let mut remaining_required_kinds = HashSet::<CredentialKind>::new();
remaining_required_kinds.extend(required_kinds);
for kind in required_kinds {
if valid_credentials.iter().any(|x| x.kind() == *kind) {
remaining_required_kinds.remove(kind);
}
}
if let Some(kind) = remaining_required_kinds.into_iter().next() {
CredentialPolicyResponse::Need(kind)
} else {
CredentialPolicyResponse::Ok
}
} else if valid_credentials.is_empty() {
CredentialPolicyResponse::NeedMoreCredentials
} else { } else {
CredentialPolicyResponse::Ok CredentialPolicyResponse::Ok
} }
} }
} }
impl CredentialPolicy for AllCredentialsPolicy {
fn is_sufficient(
&self,
_protocol: &str,
valid_credentials: &[AuthCredential],
) -> CredentialPolicyResponse {
let valid_credential_types: HashSet<CredentialKind> =
valid_credentials.iter().map(|x| x.kind()).collect();
if valid_credential_types.is_superset(&self.required_credential_types) {
CredentialPolicyResponse::Ok
} else {
CredentialPolicyResponse::Need(
self.required_credential_types
.difference(&valid_credential_types)
.cloned()
.collect(),
)
}
}
}
impl CredentialPolicy for PerProtocolCredentialPolicy {
fn is_sufficient(
&self,
protocol: &str,
valid_credentials: &[AuthCredential],
) -> CredentialPolicyResponse {
if let Some(policy) = self.protocols.get(protocol) {
policy.is_sufficient(protocol, valid_credentials)
} else {
self.default.is_sufficient(protocol, valid_credentials)
}
}
}

View file

@ -1,71 +1,66 @@
use std::time::{Duration, Instant}; use uuid::Uuid;
use once_cell::sync::Lazy;
use tracing::warn;
use super::{AuthCredential, CredentialPolicy, CredentialPolicyResponse}; use super::{AuthCredential, CredentialPolicy, CredentialPolicyResponse};
use crate::AuthResult; use crate::AuthResult;
#[allow(clippy::unwrap_used)]
pub static TIMEOUT: Lazy<Duration> = Lazy::new(|| Duration::from_secs(60 * 10));
pub struct AuthState { pub struct AuthState {
id: Uuid,
username: String, username: String,
protocol: String, protocol: String,
policy: Option<Box<dyn CredentialPolicy + Sync + Send>>, force_rejected: bool,
policy: Box<dyn CredentialPolicy + Sync + Send>,
valid_credentials: Vec<AuthCredential>, valid_credentials: Vec<AuthCredential>,
started_at: Instant,
} }
impl AuthState { impl AuthState {
pub fn new( pub(crate) fn new(
id: Uuid,
username: String, username: String,
protocol: String, protocol: String,
policy: Option<Box<dyn CredentialPolicy + Sync + Send>>, policy: Box<dyn CredentialPolicy + Sync + Send>,
) -> Self { ) -> Self {
Self { Self {
id,
username, username,
protocol, protocol,
force_rejected: false,
policy, policy,
valid_credentials: vec![], valid_credentials: vec![],
started_at: Instant::now(),
} }
} }
pub fn id(&self) -> &Uuid {
&self.id
}
pub fn username(&self) -> &str { pub fn username(&self) -> &str {
&self.username &self.username
} }
pub fn protocol(&self) -> &str {
&self.protocol
}
pub fn add_valid_credential(&mut self, credential: AuthCredential) { pub fn add_valid_credential(&mut self, credential: AuthCredential) {
self.valid_credentials.push(credential); self.valid_credentials.push(credential);
} }
pub fn is_expired(&self) -> bool { pub fn reject(&mut self) {
self.started_at.elapsed() > *TIMEOUT self.force_rejected = true;
} }
pub fn verify(&self) -> AuthResult { pub fn verify(&self) -> AuthResult {
if self.valid_credentials.is_empty() { if self.force_rejected {
warn!(
username=%self.username,
"No matching valid credentials"
);
return AuthResult::Rejected; return AuthResult::Rejected;
} }
match self
if let Some(ref policy) = self.policy { .policy
match policy.is_sufficient(&self.protocol, &self.valid_credentials[..]) { .is_sufficient(&self.protocol, &self.valid_credentials[..])
CredentialPolicyResponse::Ok => {} {
CredentialPolicyResponse::Need(kind) => { CredentialPolicyResponse::Ok => AuthResult::Accepted {
return AuthResult::Need(kind); username: self.username.clone(),
} },
CredentialPolicyResponse::NeedMoreCredentials => { CredentialPolicyResponse::Need(kinds) => AuthResult::Need(kinds),
return AuthResult::Rejected;
}
}
}
AuthResult::Accepted {
username: self.username.clone(),
} }
} }
} }

View file

@ -1,15 +1,32 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex; use once_cell::sync::Lazy;
use tokio::sync::{broadcast, Mutex};
use uuid::Uuid; use uuid::Uuid;
use super::AuthState; use super::AuthState;
use crate::{ConfigProvider, WarpgateError}; use crate::{AuthResult, ConfigProvider, WarpgateError};
#[allow(clippy::unwrap_used)]
pub static TIMEOUT: Lazy<Duration> = Lazy::new(|| Duration::from_secs(60 * 10));
struct AuthCompletionSignal {
sender: broadcast::Sender<AuthResult>,
created_at: Instant,
}
impl AuthCompletionSignal {
pub fn is_expired(&self) -> bool {
self.created_at.elapsed() > *TIMEOUT
}
}
pub struct AuthStateStore { pub struct AuthStateStore {
config_provider: Arc<Mutex<dyn ConfigProvider + Send + 'static>>, config_provider: Arc<Mutex<dyn ConfigProvider + Send + 'static>>,
store: HashMap<Uuid, AuthState>, store: HashMap<Uuid, (Arc<Mutex<AuthState>>, Instant)>,
completion_signals: HashMap<Uuid, AuthCompletionSignal>,
} }
impl AuthStateStore { impl AuthStateStore {
@ -17,47 +34,66 @@ impl AuthStateStore {
Self { Self {
store: HashMap::new(), store: HashMap::new(),
config_provider, config_provider,
completion_signals: HashMap::new(),
} }
} }
pub fn contains_key(&mut self, id: &Uuid) -> bool { pub fn contains_key(&self, id: &Uuid) -> bool {
self.store.contains_key(id) self.store.contains_key(id)
} }
pub fn get_mut(&mut self, id: &Uuid) -> Option<&mut AuthState> { pub fn get(&self, id: &Uuid) -> Option<Arc<Mutex<AuthState>>> {
self.store.get_mut(id) self.store.get(id).map(|x| x.0.clone())
} }
pub async fn create( pub async fn create(
&mut self, &mut self,
username: &str, username: &str,
protocol: &str, protocol: &str,
) -> Result<(Uuid, &mut AuthState), WarpgateError> { ) -> Result<(Uuid, Arc<Mutex<AuthState>>), WarpgateError> {
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let state = AuthState::new( let Some(policy) = self.config_provider
username.to_string(), .lock()
protocol.to_string(), .await
self.config_provider .get_credential_policy(username)
.lock() .await? else {
.await return Err(WarpgateError::UserNotFound)
.get_credential_policy(username) };
.await?,
); let state = AuthState::new(id, username.to_string(), protocol.to_string(), policy);
self.store.insert(id, state); self.store
.insert(id, (Arc::new(Mutex::new(state)), Instant::now()));
#[allow(clippy::unwrap_used)] #[allow(clippy::unwrap_used)]
Ok((id, self.store.get_mut(&id).unwrap())) Ok((id, self.get(&id).unwrap()))
}
pub fn subscribe(&mut self, id: Uuid) -> broadcast::Receiver<AuthResult> {
let signal = self.completion_signals.entry(id).or_insert_with(|| {
let (sender, _) = broadcast::channel(1);
AuthCompletionSignal {
sender,
created_at: Instant::now(),
}
});
signal.sender.subscribe()
}
pub async fn complete(&mut self, id: &Uuid) {
let Some((state, _)) = self.store.get(id) else {
return
};
if let Some(sig) = self.completion_signals.remove(id) {
let _ = sig.sender.send(state.lock().await.verify());
}
} }
pub async fn vacuum(&mut self) { pub async fn vacuum(&mut self) {
let mut to_remove = vec![]; self.store
for (id, state) in self.store.iter() { .retain(|_, (_, started_at)| started_at.elapsed() < *TIMEOUT);
if state.is_expired() {
to_remove.push(*id); self.completion_signals
} .retain(|_, signal| !signal.is_expired());
}
for id in to_remove {
self.store.remove(&id);
}
} }
} }

View file

@ -5,11 +5,12 @@ use std::time::Duration;
use poem_openapi::{Enum, Object, Union}; use poem_openapi::{Enum, Object, Union};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use url::Url;
use warpgate_sso::SsoProviderConfig; use warpgate_sso::SsoProviderConfig;
use crate::auth::CredentialKind; use crate::auth::CredentialKind;
use crate::helpers::otp::OtpSecretKey; use crate::helpers::otp::OtpSecretKey;
use crate::{ListenEndpoint, Secret}; use crate::{ListenEndpoint, Secret, WarpgateError};
const fn _default_true() -> bool { const fn _default_true() -> bool {
true true
@ -426,3 +427,24 @@ pub struct WarpgateConfig {
pub store: WarpgateConfigStore, pub store: WarpgateConfigStore,
pub paths_relative_to: PathBuf, pub paths_relative_to: PathBuf,
} }
impl WarpgateConfig {
pub fn construct_external_url(
&self,
fallback_host: Option<&str>,
) -> Result<Url, WarpgateError> {
let ext_host = self.store.external_host.as_deref().or(fallback_host);
let Some(ext_host) = ext_host else {
return Err(WarpgateError::ExternalHostNotSet);
};
let ext_port = self.store.http.listen.port();
let mut url = Url::parse(&format!("https://{ext_host}/"))?;
if ext_port != 443 {
let _ = url.set_port(Some(ext_port));
}
Ok(url)
}
}

View file

@ -1,4 +1,4 @@
use std::collections::HashSet; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
@ -11,7 +11,10 @@ use uuid::Uuid;
use warpgate_db_entities::Ticket; use warpgate_db_entities::Ticket;
use super::ConfigProvider; use super::ConfigProvider;
use crate::auth::{AuthCredential, CredentialPolicy}; use crate::auth::{
AllCredentialsPolicy, AnySingleCredentialPolicy, AuthCredential, CredentialKind,
CredentialPolicy, PerProtocolCredentialPolicy,
};
use crate::helpers::hash::verify_password_hash; use crate::helpers::hash::verify_password_hash;
use crate::helpers::otp::verify_totp; use crate::helpers::otp::verify_totp;
use crate::{Target, User, UserAuthCredential, UserSnapshot, WarpgateConfig, WarpgateError}; use crate::{Target, User, UserAuthCredential, UserSnapshot, WarpgateConfig, WarpgateError};
@ -78,9 +81,52 @@ impl ConfigProvider for FileConfigProvider {
return Ok(None); return Ok(None);
}; };
Ok(user let supported_credential_types: HashSet<CredentialKind> =
.require user.credentials.iter().map(|x| x.kind()).collect();
.map(|r| Box::new(r) as Box<dyn CredentialPolicy + Sync + Send>)) let default_policy = Box::new(AnySingleCredentialPolicy {
supported_credential_types: supported_credential_types.clone(),
}) as Box<dyn CredentialPolicy + Sync + Send>;
if let Some(req) = user.require {
let mut policy = PerProtocolCredentialPolicy {
default: default_policy,
protocols: HashMap::new(),
};
if let Some(p) = req.http {
policy.protocols.insert(
"HTTP",
Box::new(AllCredentialsPolicy {
supported_credential_types: supported_credential_types.clone(),
required_credential_types: p.into_iter().collect(),
}),
);
}
if let Some(p) = req.mysql {
policy.protocols.insert(
"MySQL",
Box::new(AllCredentialsPolicy {
supported_credential_types: supported_credential_types.clone(),
required_credential_types: p.into_iter().collect(),
}),
);
}
if let Some(p) = req.ssh {
policy.protocols.insert(
"SSH",
Box::new(AllCredentialsPolicy {
supported_credential_types,
required_credential_types: p.into_iter().collect(),
}),
);
}
Ok(Some(
Box::new(policy) as Box<dyn CredentialPolicy + Sync + Send>
))
} else {
Ok(Some(default_policy))
}
} }
async fn username_for_sso_credential( async fn username_for_sso_credential(
@ -193,6 +239,7 @@ impl ConfigProvider for FileConfigProvider {
} }
return Ok(false); return Ok(false);
} }
_ => return Err(WarpgateError::InvalidCredentialType),
} }
} }

View file

@ -1,4 +1,5 @@
mod file; mod file;
use std::collections::HashSet;
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
@ -12,11 +13,10 @@ use warpgate_db_entities::Ticket;
use crate::auth::{AuthCredential, CredentialKind, CredentialPolicy}; use crate::auth::{AuthCredential, CredentialKind, CredentialPolicy};
use crate::{Secret, Target, UserSnapshot, WarpgateError}; use crate::{Secret, Target, UserSnapshot, WarpgateError};
#[derive(Debug)] #[derive(Debug, Clone)]
pub enum AuthResult { pub enum AuthResult {
Accepted { username: String }, Accepted { username: String },
Need(CredentialKind), Need(HashSet<CredentialKind>),
NeedMoreCredentials,
Rejected, Rejected,
} }

View file

@ -9,8 +9,16 @@ pub enum WarpgateError {
DatabaseError(#[from] sea_orm::DbErr), DatabaseError(#[from] sea_orm::DbErr),
#[error("ticket not found: {0}")] #[error("ticket not found: {0}")]
InvalidTicket(Uuid), InvalidTicket(Uuid),
#[error("invalid credential type")]
InvalidCredentialType,
#[error(transparent)] #[error(transparent)]
Other(Box<dyn Error + Send + Sync>), Other(Box<dyn Error + Send + Sync>),
#[error("user not found")]
UserNotFound,
#[error("failed to parse url: {0}")]
UrlParse(#[from] url::ParseError),
#[error("external_url config option is not set")]
ExternalHostNotSet,
} }
impl ResponseError for WarpgateError { impl ResponseError for WarpgateError {

View file

@ -7,6 +7,7 @@ version = "0.4.0"
[dependencies] [dependencies]
anyhow = "1.0" anyhow = "1.0"
async-trait = "0.1" async-trait = "0.1"
chrono = {version = "0.4", features = ["serde"]}
cookie = "0.16" cookie = "0.16"
data-encoding = "2.3" data-encoding = "2.3"
delegate = "0.6" delegate = "0.6"

View file

@ -3,14 +3,18 @@ use std::sync::Arc;
use poem::session::Session; use poem::session::Session;
use poem::web::Data; use poem::web::Data;
use poem::Request; use poem::Request;
use poem_openapi::param::Path;
use poem_openapi::payload::Json; use poem_openapi::payload::Json;
use poem_openapi::{ApiResponse, Enum, Object, OpenApi}; use poem_openapi::{ApiResponse, Enum, Object, OpenApi};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::*; use tracing::*;
use warpgate_common::auth::{AuthCredential, CredentialKind}; use uuid::Uuid;
use warpgate_common::auth::{AuthCredential, AuthState, CredentialKind};
use warpgate_common::{AuthResult, Secret, Services}; use warpgate_common::{AuthResult, Secret, Services};
use crate::common::{authorize_session, get_auth_state_for_request, SessionExt}; use crate::common::{
authorize_session, endpoint_auth, get_auth_state_for_request, SessionAuthorization, SessionExt,
};
use crate::session::SessionStore; use crate::session::SessionStore;
pub struct Api; pub struct Api;
@ -33,6 +37,8 @@ enum ApiAuthState {
PasswordNeeded, PasswordNeeded,
OtpNeeded, OtpNeeded,
SsoNeeded, SsoNeeded,
WebUserApprovalNeeded,
PublicKeyNeeded,
Success, Success,
} }
@ -58,6 +64,7 @@ enum LogoutResponse {
#[derive(Object)] #[derive(Object)]
struct AuthStateResponseInternal { struct AuthStateResponseInternal {
pub protocol: String,
pub state: ApiAuthState, pub state: ApiAuthState,
} }
@ -65,17 +72,22 @@ struct AuthStateResponseInternal {
enum AuthStateResponse { enum AuthStateResponse {
#[oai(status = 200)] #[oai(status = 200)]
Ok(Json<AuthStateResponseInternal>), Ok(Json<AuthStateResponseInternal>),
#[oai(status = 404)]
NotFound,
} }
impl From<AuthResult> for ApiAuthState { impl From<AuthResult> for ApiAuthState {
fn from(state: AuthResult) -> Self { fn from(state: AuthResult) -> Self {
match state { match state {
AuthResult::Rejected => ApiAuthState::Failed, AuthResult::Rejected => ApiAuthState::Failed,
AuthResult::Need(CredentialKind::Password) => ApiAuthState::PasswordNeeded, AuthResult::Need(kinds) => match kinds.iter().next() {
AuthResult::Need(CredentialKind::Otp) => ApiAuthState::OtpNeeded, Some(CredentialKind::Password) => ApiAuthState::PasswordNeeded,
AuthResult::Need(CredentialKind::Sso) => ApiAuthState::SsoNeeded, Some(CredentialKind::Otp) => ApiAuthState::OtpNeeded,
AuthResult::Need(CredentialKind::PublicKey) => ApiAuthState::Failed, Some(CredentialKind::Sso) => ApiAuthState::SsoNeeded,
AuthResult::NeedMoreCredentials => ApiAuthState::Failed, Some(CredentialKind::WebUserApproval) => ApiAuthState::WebUserApprovalNeeded,
Some(CredentialKind::PublicKey) => ApiAuthState::PublicKeyNeeded,
None => ApiAuthState::Failed,
},
AuthResult::Accepted { .. } => ApiAuthState::Success, AuthResult::Accepted { .. } => ApiAuthState::Success,
} }
} }
@ -92,8 +104,9 @@ impl Api {
body: Json<LoginRequest>, body: Json<LoginRequest>,
) -> poem::Result<LoginResponse> { ) -> poem::Result<LoginResponse> {
let mut auth_state_store = services.auth_state_store.lock().await; let mut auth_state_store = services.auth_state_store.lock().await;
let state = let state_arc =
get_auth_state_for_request(&body.username, session, &mut auth_state_store).await?; get_auth_state_for_request(&body.username, session, &mut auth_state_store).await?;
let mut state = state_arc.lock().await;
let mut cp = services.config_provider.lock().await; let mut cp = services.config_provider.lock().await;
@ -107,6 +120,7 @@ impl Api {
match state.verify() { match state.verify() {
AuthResult::Accepted { username } => { AuthResult::Accepted { username } => {
auth_state_store.complete(state.id()).await;
authorize_session(req, username).await?; authorize_session(req, username).await?;
Ok(LoginResponse::Success) Ok(LoginResponse::Success)
} }
@ -131,12 +145,14 @@ impl Api {
let mut auth_state_store = services.auth_state_store.lock().await; let mut auth_state_store = services.auth_state_store.lock().await;
let Some(state) = state_id.and_then(|id| auth_state_store.get_mut(&id.0)) else { let Some(state_arc) = state_id.and_then(|id| auth_state_store.get(&id.0)) else {
return Ok(LoginResponse::Failure(Json(LoginFailureResponse { return Ok(LoginResponse::Failure(Json(LoginFailureResponse {
state: ApiAuthState::NotStarted, state: ApiAuthState::NotStarted,
}))) })))
}; };
let mut state = state_arc.lock().await;
let mut cp = services.config_provider.lock().await; let mut cp = services.config_provider.lock().await;
let otp_cred = AuthCredential::Otp(body.otp.clone().into()); let otp_cred = AuthCredential::Otp(body.otp.clone().into());
@ -146,6 +162,7 @@ impl Api {
match state.verify() { match state.verify() {
AuthResult::Accepted { username } => { AuthResult::Accepted { username } => {
auth_state_store.complete(state.id()).await;
authorize_session(req, username).await?; authorize_session(req, username).await?;
Ok(LoginResponse::Success) Ok(LoginResponse::Success)
} }
@ -155,27 +172,6 @@ impl Api {
} }
} }
#[oai(path = "/auth/state", method = "get", operation_id = "getAuthState")]
async fn api_auth_state(
&self,
session: &Session,
services: Data<&Services>,
) -> poem::Result<AuthStateResponse> {
let state_id = session.get_auth_state_id();
let mut auth_state_store = services.auth_state_store.lock().await;
let Some(state) = state_id.and_then(|id| auth_state_store.get_mut(&id.0)) else {
return Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal {
state: ApiAuthState::NotStarted,
})));
};
Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal {
state: state.verify().into(),
})))
}
#[oai(path = "/auth/logout", method = "post", operation_id = "logout")] #[oai(path = "/auth/logout", method = "post", operation_id = "logout")]
async fn api_auth_logout( async fn api_auth_logout(
&self, &self,
@ -187,4 +183,129 @@ impl Api {
info!("Logged out"); info!("Logged out");
Ok(LogoutResponse::Success) Ok(LogoutResponse::Success)
} }
#[oai(
path = "/auth/state",
method = "get",
operation_id = "getDefaultAuthState"
)]
async fn api_default_auth_state(
&self,
session: &Session,
services: Data<&Services>,
) -> poem::Result<AuthStateResponse> {
let Some(state_id) = session.get_auth_state_id() else {
return Ok(AuthStateResponse::NotFound)
};
let store = services.auth_state_store.lock().await;
let Some(state_arc) = store.get(&state_id.0) else {
return Ok(AuthStateResponse::NotFound);
};
serialize_auth_state_inner(state_arc).await
}
#[oai(
path = "/auth/state/:id",
method = "get",
operation_id = "get_auth_state",
transform = "endpoint_auth"
)]
async fn api_auth_state(
&self,
services: Data<&Services>,
auth: Option<Data<&SessionAuthorization>>,
id: Path<Uuid>,
) -> poem::Result<AuthStateResponse> {
let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else {
return Ok(AuthStateResponse::NotFound);
};
serialize_auth_state_inner(state_arc).await
}
#[oai(
path = "/auth/state/:id/approve",
method = "post",
operation_id = "approve_auth",
transform = "endpoint_auth"
)]
async fn api_approve_auth(
&self,
services: Data<&Services>,
auth: Option<Data<&SessionAuthorization>>,
id: Path<Uuid>,
) -> poem::Result<AuthStateResponse> {
let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else {
return Ok(AuthStateResponse::NotFound);
};
let auth_result = {
let mut state = state_arc.lock().await;
state.add_valid_credential(AuthCredential::WebUserApproval);
state.verify()
};
if let AuthResult::Accepted { .. } = auth_result {
services.auth_state_store.lock().await.complete(&*id).await;
}
serialize_auth_state_inner(state_arc).await
}
#[oai(
path = "/auth/state/:id/reject",
method = "post",
operation_id = "reject_auth",
transform = "endpoint_auth"
)]
async fn api_reject_auth(
&self,
services: Data<&Services>,
auth: Option<Data<&SessionAuthorization>>,
id: Path<Uuid>,
) -> poem::Result<AuthStateResponse> {
let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else {
return Ok(AuthStateResponse::NotFound);
};
state_arc.lock().await.reject();
services.auth_state_store.lock().await.complete(&*id).await;
serialize_auth_state_inner(state_arc).await
}
}
async fn get_auth_state(
id: &Uuid,
services: &Services,
auth: Option<&SessionAuthorization>,
) -> Option<Arc<Mutex<AuthState>>> {
let store = services.auth_state_store.lock().await;
let Some(auth) = auth else {
return None;
};
let SessionAuthorization::User(username) = auth else {
return None;
};
let Some(state_arc) = store.get(&*id) else {
return None;
};
{
let state = state_arc.lock().await;
if state.username() != username {
return None;
}
}
Some(state_arc)
}
async fn serialize_auth_state_inner(
state_arc: Arc<Mutex<AuthState>>,
) -> poem::Result<AuthStateResponse> {
let state = state_arc.lock().await;
Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal {
protocol: state.protocol().to_string(),
state: state.verify().into(),
})))
} }

View file

@ -1,10 +1,9 @@
use poem::session::Session; use poem::session::Session;
use poem::web::Data; use poem::web::Data;
use poem::Request; use poem::Request;
use poem_openapi::param::Path; use poem_openapi::param::{Path, Query};
use poem_openapi::payload::Json; use poem_openapi::payload::Json;
use poem_openapi::{ApiResponse, Object, OpenApi}; use poem_openapi::{ApiResponse, Object, OpenApi};
use reqwest::Url;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use warpgate_common::Services; use warpgate_common::Services;
use warpgate_sso::{SsoClient, SsoLoginRequest}; use warpgate_sso::{SsoClient, SsoLoginRequest};
@ -31,6 +30,7 @@ pub static SSO_CONTEXT_SESSION_KEY: &str = "sso_request";
pub struct SsoContext { pub struct SsoContext {
pub provider: String, pub provider: String,
pub request: SsoLoginRequest, pub request: SsoLoginRequest,
pub next_url: Option<String>,
} }
#[OpenApi] #[OpenApi]
@ -46,31 +46,14 @@ impl Api {
session: &Session, session: &Session,
services: Data<&Services>, services: Data<&Services>,
name: Path<String>, name: Path<String>,
next: Query<Option<String>>,
) -> poem::Result<StartSsoResponse> { ) -> poem::Result<StartSsoResponse> {
let config = services.config.lock().await; let config = services.config.lock().await;
let name = name.0; let name = name.0;
let ext_host = config
.store
.external_host
.as_deref()
.or_else(|| req.original_uri().host());
let Some(ext_host) = ext_host else {
return Err(poem::Error::from_string("external_host config option is required for SSO", http::status::StatusCode::INTERNAL_SERVER_ERROR));
};
let ext_port = config.store.http.listen.port();
let mut return_url = Url::parse(&format!("https://{ext_host}/@warpgate/api/sso/return")) let mut return_url = config.construct_external_url(req.original_uri().host())?;
.map_err(|e| { return_url.set_path("@warpgate/api/sso/return");
poem::Error::from_string(
format!("failed to construct the return URL: {e}"),
http::status::StatusCode::INTERNAL_SERVER_ERROR,
)
})?;
if ext_port != 443 {
let _ = return_url.set_port(Some(ext_port));
}
let Some(provider_config) = config.store.sso_providers.iter().find(|p| p.name == *name) else { let Some(provider_config) = config.store.sso_providers.iter().find(|p| p.name == *name) else {
return Ok(StartSsoResponse::NotFound); return Ok(StartSsoResponse::NotFound);
@ -84,10 +67,14 @@ impl Api {
.map_err(poem::error::InternalServerError)?; .map_err(poem::error::InternalServerError)?;
let url = sso_req.auth_url().to_string(); let url = sso_req.auth_url().to_string();
session.set(SSO_CONTEXT_SESSION_KEY, SsoContext { session.set(
provider: name, SSO_CONTEXT_SESSION_KEY,
request: sso_req, SsoContext {
}); provider: name,
request: sso_req,
next_url: next.0.clone(),
},
);
Ok(StartSsoResponse::Ok(Json(StartSsoResponseParams { url }))) Ok(StartSsoResponse::Ok(Json(StartSsoResponseParams { url })))
} }

View file

@ -120,8 +120,9 @@ impl Api {
}; };
let mut auth_state_store = services.auth_state_store.lock().await; let mut auth_state_store = services.auth_state_store.lock().await;
let state = get_auth_state_for_request(&username, session, &mut auth_state_store).await?; let state_arc = get_auth_state_for_request(&username, session, &mut auth_state_store).await?;
let mut state = state_arc.lock().await;
let mut cp = services.config_provider.lock().await; let mut cp = services.config_provider.lock().await;
if cp.validate_credential(&username, &cred).await? { if cp.validate_credential(&username, &cred).await? {
@ -130,11 +131,15 @@ impl Api {
match state.verify() { match state.verify() {
AuthResult::Accepted { username } => { AuthResult::Accepted { username } => {
auth_state_store.complete(state.id()).await;
authorize_session(req, username).await?; authorize_session(req, username).await?;
} }
_ => () _ => (),
} }
Ok(Response::new(ReturnToSsoResponse::Ok).header("Location", "/@warpgate")) Ok(Response::new(ReturnToSsoResponse::Ok).header(
"Location",
context.next_url.as_deref().unwrap_or("/@warpgate#/login"),
))
} }
} }

View file

@ -170,18 +170,18 @@ pub fn gateway_redirect(req: &Request) -> Response {
.unwrap_or("".into()); .unwrap_or("".into());
let path = format!( let path = format!(
"/@warpgate?next={}", "/@warpgate#/login?next={}",
utf8_percent_encode(&path, NON_ALPHANUMERIC), utf8_percent_encode(&path, NON_ALPHANUMERIC),
); );
Redirect::temporary(path).into_response() Redirect::temporary(path).into_response()
} }
pub async fn get_auth_state_for_request<'a>( pub async fn get_auth_state_for_request(
username: &str, username: &str,
session: &Session, session: &Session,
store: &'a mut AuthStateStore, store: &mut AuthStateStore,
) -> Result<&'a mut AuthState, WarpgateError> { ) -> Result<Arc<Mutex<AuthState>>, WarpgateError> {
match session.get_auth_state_id() { match session.get_auth_state_id() {
Some(id) => { Some(id) => {
if !store.contains_key(&id.0) { if !store.contains_key(&id.0) {
@ -192,7 +192,7 @@ pub async fn get_auth_state_for_request<'a>(
}; };
match session.get_auth_state_id() { match session.get_auth_state_id() {
Some(id) => Ok(store.get_mut(&id.0).unwrap()), Some(id) => Ok(store.get(&id.0).unwrap()),
None => { None => {
let (id, state) = store let (id, state) = store
.create(&username, crate::common::PROTOCOL_NAME) .create(&username, crate::common::PROTOCOL_NAME)

View file

@ -7,7 +7,7 @@ use tokio::net::TcpStream;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::*; use tracing::*;
use uuid::Uuid; use uuid::Uuid;
use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState}; use warpgate_common::auth::{AuthCredential, AuthSelector};
use warpgate_common::helpers::rng::get_crypto_rng; use warpgate_common::helpers::rng::get_crypto_rng;
use warpgate_common::{ use warpgate_common::{
authorize_ticket, AuthResult, Secret, Services, TargetMySqlOptions, TargetOptions, authorize_ticket, AuthResult, Secret, Services, TargetMySqlOptions, TargetOptions,
@ -180,15 +180,20 @@ impl MySqlSession {
username, username,
target_name, target_name,
} => { } => {
let user_auth_result = { let state_arc = self
let mut cp = self.services.config_provider.lock().await; .services
.auth_state_store
.lock()
.await
.create(&username, crate::common::PROTOCOL_NAME)
.await?
.1;
let mut state = state_arc.lock().await;
let user_auth_result = {
let credential = AuthCredential::Password(password); let credential = AuthCredential::Password(password);
let mut state = AuthState::new(
username.clone(), let mut cp = self.services.config_provider.lock().await;
crate::common::PROTOCOL_NAME.to_string(),
cp.get_credential_policy(&username).await?,
);
if cp.validate_credential(&username, &credential).await? { if cp.validate_credential(&username, &credential).await? {
state.add_valid_credential(credential); state.add_valid_credential(credential);
} }
@ -198,6 +203,12 @@ impl MySqlSession {
match user_auth_result { match user_auth_result {
AuthResult::Accepted { username } => { AuthResult::Accepted { username } => {
self.services
.auth_state_store
.lock()
.await
.complete(state.id())
.await;
let target_auth_result = { let target_auth_result = {
self.services self.services
.config_provider .config_provider
@ -216,9 +227,7 @@ impl MySqlSession {
} }
self.run_authorized(handshake, username, target_name).await self.run_authorized(handshake, username, target_name).await
} }
AuthResult::Rejected AuthResult::Rejected | AuthResult::Need(_) => fail(&mut self).await, // TODO SSO
| AuthResult::Need(_)
| AuthResult::NeedMoreCredentials => fail(&mut self).await, // TODO SSO
} }
} }
AuthSelector::Ticket { secret } => { AuthSelector::Ticket { secret } => {

View file

@ -12,7 +12,7 @@ bimap = "0.6"
bytes = "1.2" bytes = "1.2"
dialoguer = "0.10" dialoguer = "0.10"
futures = "0.3" futures = "0.3"
russh = {version = "0.34.0-beta.7", features = ["openssl"]} russh = {version = "0.34.0-beta.8", features = ["openssl"]}
russh-keys = {version = "0.22.0-beta.4", features = ["openssl"]} russh-keys = {version = "0.22.0-beta.4", features = ["openssl"]}
sea-orm = {version = "^0.9", features = ["runtime-tokio-native-tls"], default-features = false} sea-orm = {version = "^0.9", features = ["runtime-tokio-native-tls"], default-features = false}
thiserror = "1.0" thiserror = "1.0"

View file

@ -23,6 +23,7 @@ pub async fn run_server(services: Services, address: SocketAddr) -> Result<()> {
let config = services.config.lock().await; let config = services.config.lock().await;
russh::server::Config { russh::server::Config {
auth_rejection_time: std::time::Duration::from_secs(1), auth_rejection_time: std::time::Duration::from_secs(1),
connection_timeout: Some(std::time::Duration::from_secs(300)),
methods: MethodSet::PUBLICKEY | MethodSet::PASSWORD | MethodSet::KEYBOARD_INTERACTIVE, methods: MethodSet::PUBLICKEY | MethodSet::PASSWORD | MethodSet::KEYBOARD_INTERACTIVE,
keys: load_host_keys(&config)?, keys: load_host_keys(&config)?,
..Default::default() ..Default::default()

View file

@ -1,6 +1,6 @@
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::hash_map::Entry::Vacant; use std::collections::hash_map::Entry::Vacant;
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use std::net::{Ipv4Addr, SocketAddr}; use std::net::{Ipv4Addr, SocketAddr};
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
@ -10,11 +10,11 @@ use anyhow::{Context, Result};
use bimap::BiMap; use bimap::BiMap;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use russh::server::Session; use russh::server::Session;
use russh::{CryptoVec, Sig}; use russh::{CryptoVec, MethodSet, Sig};
use russh_keys::key::PublicKey; use russh_keys::key::PublicKey;
use russh_keys::PublicKeyBase64; use russh_keys::PublicKeyBase64;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{oneshot, Mutex}; use tokio::sync::{broadcast, oneshot, Mutex};
use tracing::*; use tracing::*;
use uuid::Uuid; use uuid::Uuid;
use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState, CredentialKind}; use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState, CredentialKind};
@ -53,6 +53,12 @@ enum Event {
Client(RCEvent), Client(RCEvent),
} }
enum KeyboardInteractiveState {
None,
OtpRequested,
WebAuthRequested(broadcast::Receiver<AuthResult>),
}
pub struct ServerSession { pub struct ServerSession {
pub id: SessionId, pub id: SessionId,
username: Option<String>, username: Option<String>,
@ -74,7 +80,8 @@ pub struct ServerSession {
hub: EventHub<Event>, hub: EventHub<Event>,
event_sender: EventSender<Event>, event_sender: EventSender<Event>,
service_output: ServiceOutput, service_output: ServiceOutput,
auth_state: Option<AuthState>, auth_state: Option<Arc<Mutex<AuthState>>>,
keyboard_interactive_state: KeyboardInteractiveState,
} }
fn session_debug_tag(id: &SessionId, remote_address: &SocketAddr) -> String { fn session_debug_tag(id: &SessionId, remote_address: &SocketAddr) -> String {
@ -142,6 +149,7 @@ impl ServerSession {
so_tx.send(BytesMut::from(data).freeze()).context("x") so_tx.send(BytesMut::from(data).freeze()).context("x")
})), })),
auth_state: None, auth_state: None,
keyboard_interactive_state: KeyboardInteractiveState::None,
}; };
let this = Arc::new(Mutex::new(this)); let this = Arc::new(Mutex::new(this));
@ -217,18 +225,23 @@ impl ServerSession {
Ok(this) Ok(this)
} }
async fn get_auth_state(&mut self, username: &str) -> Result<&mut AuthState> { async fn get_auth_state(&mut self, username: &str) -> Result<Arc<Mutex<AuthState>>> {
#[allow(clippy::unwrap_used)] #[allow(clippy::unwrap_used)]
if self.auth_state.is_none() || self.auth_state.as_ref().unwrap().username() != username { if self.auth_state.is_none()
let mut cp = self.services.config_provider.lock().await; || self.auth_state.as_ref().unwrap().lock().await.username() != username
self.auth_state = Some(AuthState::new( {
username.to_string(), let state = self
crate::PROTOCOL_NAME.to_string(), .services
cp.get_credential_policy(username).await?, .auth_state_store
)); .lock()
.await
.create(username, crate::PROTOCOL_NAME)
.await?
.1;
self.auth_state = Some(state);
} }
#[allow(clippy::unwrap_used)] #[allow(clippy::unwrap_used)]
Ok(self.auth_state.as_mut().unwrap()) Ok(self.auth_state.as_ref().map(Clone::clone).unwrap())
} }
pub fn make_logging_span(&self) -> tracing::Span { pub fn make_logging_span(&self) -> tracing::Span {
@ -939,13 +952,17 @@ impl ServerSession {
.await .await
{ {
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept, Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
Ok(AuthResult::Rejected) => russh::server::Auth::Reject, Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
Ok(AuthResult::Need(_) | AuthResult::NeedMoreCredentials) => { proceed_with_methods: Some(MethodSet::all()),
russh::server::Auth::Reject },
} Ok(AuthResult::Need(kinds)) => russh::server::Auth::Reject {
proceed_with_methods: Some(self.get_remaining_auth_methods(kinds)),
},
Err(error) => { Err(error) => {
error!(?error, "Failed to verify credentials"); error!(?error, "Failed to verify credentials");
russh::server::Auth::Reject russh::server::Auth::Reject {
proceed_with_methods: None,
}
} }
} }
} }
@ -963,13 +980,17 @@ impl ServerSession {
.await .await
{ {
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept, Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
Ok(AuthResult::Rejected) => russh::server::Auth::Reject, Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
Ok(AuthResult::Need(_) | AuthResult::NeedMoreCredentials) => { proceed_with_methods: None,
russh::server::Auth::Reject },
} Ok(AuthResult::Need(_)) => russh::server::Auth::Reject {
proceed_with_methods: None,
},
Err(error) => { Err(error) => {
error!(?error, "Failed to verify credentials"); error!(?error, "Failed to verify credentials");
russh::server::Auth::Reject russh::server::Auth::Reject {
proceed_with_methods: None,
}
} }
} }
} }
@ -982,27 +1003,111 @@ impl ServerSession {
let selector: AuthSelector = ssh_username.expose_secret().into(); let selector: AuthSelector = ssh_username.expose_secret().into();
info!("Keyboard-interactive auth as {:?}", selector); info!("Keyboard-interactive auth as {:?}", selector);
let cred = response.map(AuthCredential::Otp); let cred;
match &mut self.keyboard_interactive_state {
KeyboardInteractiveState::None => {
cred = None;
}
KeyboardInteractiveState::OtpRequested => {
cred = response.map(AuthCredential::Otp);
}
KeyboardInteractiveState::WebAuthRequested(event) => {
cred = None;
let _ = event.recv().await;
// the auth state has been updated by now
}
}
self.keyboard_interactive_state = KeyboardInteractiveState::None;
match self.try_auth(&selector, cred).await { match self.try_auth(&selector, cred).await {
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept, Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
Ok( Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
AuthResult::Rejected proceed_with_methods: None,
| AuthResult::NeedMoreCredentials
| AuthResult::Need(CredentialKind::Otp),
) => russh::server::Auth::Partial {
name: Cow::Borrowed("Two-factor authentication"),
instructions: Cow::Borrowed(""),
prompts: Cow::Owned(vec![(Cow::Borrowed("One-time password: "), true)]),
}, },
Ok(AuthResult::Need(_)) => russh::server::Auth::Reject, // TODO SSO Ok(AuthResult::Need(kinds)) => {
if kinds.contains(&CredentialKind::Otp) {
self.keyboard_interactive_state = KeyboardInteractiveState::OtpRequested;
russh::server::Auth::Partial {
name: Cow::Borrowed("Two-factor authentication"),
instructions: Cow::Borrowed(""),
prompts: Cow::Owned(vec![(Cow::Borrowed("One-time password: "), true)]),
}
} else if kinds.contains(&CredentialKind::WebUserApproval) {
let Some(auth_state) = self.auth_state.as_ref() else {
return russh::server::Auth::Reject { proceed_with_methods: None};
};
let auth_state_id = *auth_state.lock().await.id();
let event = self
.services
.auth_state_store
.lock()
.await
.subscribe(auth_state_id);
self.keyboard_interactive_state =
KeyboardInteractiveState::WebAuthRequested(event);
let mut login_url = match self
.services
.config
.lock()
.await
.construct_external_url(None)
{
Ok(url) => url,
Err(error) => {
error!(?error, "Failed to construct external URL");
return russh::server::Auth::Reject {
proceed_with_methods: None,
};
}
};
login_url.set_path("@warpgate");
login_url
.set_fragment(Some(&format!("/login?next=%2Flogin%2F{auth_state_id}")));
russh::server::Auth::Partial {
name: Cow::Owned(format!(
concat!(
"----------------------------------------------------------------\n",
"Warpgate authentication: please open {} in your browser\n",
"----------------------------------------------------------------\n"
),
login_url
)),
instructions: Cow::Borrowed(""),
prompts: Cow::Owned(vec![(Cow::Borrowed("Press Enter when done: "), true)]),
}
} else {
russh::server::Auth::Reject {
proceed_with_methods: None,
}
}
}
Err(error) => { Err(error) => {
error!(?error, "Failed to verify credentials"); error!(?error, "Failed to verify credentials");
russh::server::Auth::Reject russh::server::Auth::Reject {
proceed_with_methods: None,
}
} }
} }
} }
fn get_remaining_auth_methods(&self, kinds: HashSet<CredentialKind>) -> MethodSet {
let mut m = MethodSet::empty();
for kind in kinds {
match kind {
CredentialKind::Password => m.insert(MethodSet::PASSWORD),
CredentialKind::Otp => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
CredentialKind::WebUserApproval => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
CredentialKind::PublicKey => m.insert(MethodSet::PUBLICKEY),
CredentialKind::Sso => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
}
}
m
}
async fn try_auth( async fn try_auth(
&mut self, &mut self,
selector: &AuthSelector, selector: &AuthSelector,
@ -1014,7 +1119,9 @@ impl ServerSession {
target_name, target_name,
} => { } => {
let cp = self.services.config_provider.clone(); let cp = self.services.config_provider.clone();
let state = self.get_auth_state(username).await?;
let state_arc = self.get_auth_state(username).await?;
let mut state = state_arc.lock().await;
if let Some(credential) = credential { if let Some(credential) = credential {
if cp if cp
@ -1031,6 +1138,12 @@ impl ServerSession {
match user_auth_result { match user_auth_result {
AuthResult::Accepted { username } => { AuthResult::Accepted { username } => {
self.services
.auth_state_store
.lock()
.await
.complete(state.id())
.await;
let target_auth_result = { let target_auth_result = {
self.services self.services
.config_provider .config_provider

View file

@ -2,22 +2,25 @@
import { faSignOut } from '@fortawesome/free-solid-svg-icons' import { faSignOut } from '@fortawesome/free-solid-svg-icons'
import { Alert, Spinner } from 'sveltestrap' import { Alert, Spinner } from 'sveltestrap'
import Fa from 'svelte-fa' import Fa from 'svelte-fa'
import Router, { push } from 'svelte-spa-router'
import { wrap } from 'svelte-spa-router/wrap'
import { get } from 'svelte/store'
import { api } from 'gateway/lib/api' import { api } from 'gateway/lib/api'
import { reloadServerInfo, serverInfo } from 'gateway/lib/store' import { reloadServerInfo, serverInfo } from 'gateway/lib/store'
import ThemeSwitcher from 'common/ThemeSwitcher.svelte' import ThemeSwitcher from 'common/ThemeSwitcher.svelte'
import Login from './Login.svelte'
import TargetList from './TargetList.svelte'
import Logo from 'common/Logo.svelte' import Logo from 'common/Logo.svelte'
let redirecting = false let redirecting = false
let serverInfoPromise = reloadServerInfo()
async function init () { async function init () {
await reloadServerInfo() await serverInfoPromise
} }
async function logout () { async function logout () {
await api.logout() await api.logout()
await reloadServerInfo() await reloadServerInfo()
push('/login')
} }
function onPageResume () { function onPageResume () {
@ -25,6 +28,36 @@ function onPageResume () {
init() init()
} }
async function requireLogin (detail) {
await serverInfoPromise
if (!get(serverInfo)?.username) {
let url = detail.location
if (detail.querystring) {
url += '?' + detail.querystring
}
push('/login?next=' + encodeURIComponent(url))
return false
}
return true
}
const routes = {
'/': wrap({
asyncComponent: () => import('./TargetList.svelte'),
props: {
'on:navigation': () => redirecting = true,
},
conditions: [requireLogin],
}),
'/login': wrap({
asyncComponent: () => import('./Login.svelte'),
}),
'/login/:stateId': wrap({
asyncComponent: () => import('./OutOfBandAuth.svelte'),
conditions: [requireLogin],
}),
}
init() init()
</script> </script>
@ -38,9 +71,9 @@ init()
<Spinner /> <Spinner />
{:else} {:else}
<div class="d-flex align-items-center mt-5 mb-5"> <div class="d-flex align-items-center mt-5 mb-5">
<div class="logo"> <a class="logo" href="/@warpgate">
<Logo /> <Logo />
</div> </a>
{#if $serverInfo?.username} {#if $serverInfo?.username}
<div class="ms-auto"> <div class="ms-auto">
@ -56,12 +89,7 @@ init()
</div> </div>
<main> <main>
{#if $serverInfo?.username} <Router {routes}/>
<TargetList
on:navigation={() => redirecting = true} />
{:else}
<Login />
{/if}
</main> </main>
<footer class="mt-5"> <footer class="mt-5">

View file

@ -1,5 +1,6 @@
<script lang="ts"> <script lang="ts">
import { replace } from 'svelte-spa-router' import { get } from 'svelte/store'
import { querystring, replace } from 'svelte-spa-router'
import { Alert, FormGroup, Spinner } from 'sveltestrap' import { Alert, FormGroup, Spinner } from 'sveltestrap'
import Fa from 'svelte-fa' import Fa from 'svelte-fa'
import { faArrowRight } from '@fortawesome/free-solid-svg-icons' import { faArrowRight } from '@fortawesome/free-solid-svg-icons'
@ -9,25 +10,47 @@ import { api, ApiAuthState, LoginFailureResponseFromJSON, SsoProviderDescription
import { reloadServerInfo } from 'gateway/lib/store' import { reloadServerInfo } from 'gateway/lib/store'
import AsyncButton from 'common/AsyncButton.svelte' import AsyncButton from 'common/AsyncButton.svelte'
export let params: { stateId?: string } = {}
let error: Error|null = null let error: Error|null = null
let username = '' let username = ''
let password = '' let password = ''
let otp = '' let otp = ''
let busy = false let busy = false
let authState = ApiAuthState.NotStarted let authState: ApiAuthState|undefined = undefined
let ssoProvidersPromise = api.getSsoProviders() let ssoProvidersPromise = api.getSsoProviders()
const nextURL = new URLSearchParams(location.search).get('next') const nextURL = new URLSearchParams(get(querystring)).get('next') ?? undefined
const serverErrorMessage = new URLSearchParams(location.search).get('login_error') const serverErrorMessage = new URLSearchParams(location.search).get('login_error')
async function init () { async function init () {
authState = (await api.getAuthState()).state try {
authState = (await api.getDefaultAuthState()).state
} catch (err) {
if (err.status) {
const response = err as Response
if (response.status === 404) {
authState = ApiAuthState.NotStarted
}
}
}
continueWithState() continueWithState()
} }
function success () {
if (nextURL) {
replace(nextURL)
} else {
replace('/')
}
}
async function continueWithState () { async function continueWithState () {
if (authState === ApiAuthState.Success) {
success()
}
if (authState === ApiAuthState.SsoNeeded) { if (authState === ApiAuthState.SsoNeeded) {
const providers = await ssoProvidersPromise const providers = await ssoProvidersPromise
if (!providers.length) { if (!providers.length) {
@ -65,18 +88,15 @@ async function _login () {
}, },
}) })
} }
if (nextURL) { await reloadServerInfo()
location.href = nextURL success()
} else {
await reloadServerInfo()
replace('/')
}
} catch (err) { } catch (err) {
if (err.status) { if (err.status) {
const response = err as Response const response = err as Response
if (response.status === 401) { if (response.status === 401) {
const failure = LoginFailureResponseFromJSON(await response.json()) const failure = LoginFailureResponseFromJSON(await response.json())
authState = failure.state authState = failure.state
continueWithState() continueWithState()
} else { } else {
error = new Error(await response.text()) error = new Error(await response.text())
@ -96,8 +116,8 @@ function onInputKey (event: KeyboardEvent) {
async function startSSO (provider: SsoProviderDescription) { async function startSSO (provider: SsoProviderDescription) {
busy = true busy = true
try { try {
const params = await api.startSso(provider) const p = await api.startSso({ name: provider.name, next: nextURL })
location.href = params.url location.href = p.url
} catch { } catch {
busy = false busy = false
} }
@ -115,6 +135,10 @@ async function startSSO (provider: SsoProviderDescription) {
<h1>Continue login</h1> <h1>Continue login</h1>
{/if} {/if}
</div> </div>
{#if params.stateId}
//todo
loggin in for auth state id {params.stateId}
{/if}
{#if authState === ApiAuthState.OtpNeeded} {#if authState === ApiAuthState.OtpNeeded}
<FormGroup floating label="One-time password"> <FormGroup floating label="One-time password">
<!-- svelte-ignore a11y-autofocus --> <!-- svelte-ignore a11y-autofocus -->

View file

@ -0,0 +1,65 @@
<script lang="ts">
import { Alert, Spinner } from 'sveltestrap'
import { api, ApiAuthState, AuthStateResponseInternal } from 'gateway/lib/api'
import AsyncButton from 'common/AsyncButton.svelte'
export let params: { stateId: string }
let authState: AuthStateResponseInternal
async function reload () {
authState = await api.getAuthState({ id: params.stateId })
}
async function init () {
await reload()
}
async function approve () {
api.approveAuth({ id: params.stateId })
await reload()
}
async function reject () {
api.rejectAuth({ id: params.stateId })
await reload()
}
</script>
{#await init()}
<Spinner />
{:then}
<div class="page-summary-bar">
<h1>Authorization request</h1>
</div>
<p>Authorize this {authState.protocol} session?</p>
{#if authState.state === ApiAuthState.Success}
<Alert color="success">
Approved
</Alert>
{:else if authState.state === ApiAuthState.Failed}
<Alert color="danger">
Rejected
</Alert>
{:else}
<div class="d-flex">
<AsyncButton
color="primary"
class="d-flex align-items-center ms-auto"
click={approve}
>
Authorize
</AsyncButton>
<AsyncButton
outline
color="secondary"
class="d-flex align-items-center ms-2"
click={reject}
>
Reject
</AsyncButton>
</div>
{/if}
{/await}

View file

@ -71,6 +71,16 @@
"operationId": "otpLogin" "operationId": "otpLogin"
} }
}, },
"/auth/logout": {
"post": {
"responses": {
"201": {
"description": ""
}
},
"operationId": "logout"
}
},
"/auth/state": { "/auth/state": {
"get": { "get": {
"responses": { "responses": {
@ -83,19 +93,108 @@
} }
} }
} }
} },
}, "404": {
"operationId": "getAuthState"
}
},
"/auth/logout": {
"post": {
"responses": {
"201": {
"description": "" "description": ""
} }
}, },
"operationId": "logout" "operationId": "getDefaultAuthState"
}
},
"/auth/state/{id}": {
"get": {
"parameters": [
{
"name": "id",
"schema": {
"type": "string",
"format": "uuid"
},
"in": "path",
"required": true,
"deprecated": false
}
],
"responses": {
"200": {
"description": "",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/AuthStateResponseInternal"
}
}
}
},
"404": {
"description": ""
}
},
"operationId": "get_auth_state"
}
},
"/auth/state/{id}/approve": {
"post": {
"parameters": [
{
"name": "id",
"schema": {
"type": "string",
"format": "uuid"
},
"in": "path",
"required": true,
"deprecated": false
}
],
"responses": {
"200": {
"description": "",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/AuthStateResponseInternal"
}
}
}
},
"404": {
"description": ""
}
},
"operationId": "approve_auth"
}
},
"/auth/state/{id}/reject": {
"post": {
"parameters": [
{
"name": "id",
"schema": {
"type": "string",
"format": "uuid"
},
"in": "path",
"required": true,
"deprecated": false
}
],
"responses": {
"200": {
"description": "",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/AuthStateResponseInternal"
}
}
}
},
"404": {
"description": ""
}
},
"operationId": "reject_auth"
} }
}, },
"/info": { "/info": {
@ -187,6 +286,15 @@
"in": "path", "in": "path",
"required": true, "required": true,
"deprecated": false "deprecated": false
},
{
"name": "next",
"schema": {
"type": "string"
},
"in": "query",
"required": false,
"deprecated": false
} }
], ],
"responses": { "responses": {
@ -218,15 +326,21 @@
"PasswordNeeded", "PasswordNeeded",
"OtpNeeded", "OtpNeeded",
"SsoNeeded", "SsoNeeded",
"WebUserApprovalNeeded",
"PublicKeyNeeded",
"Success" "Success"
] ]
}, },
"AuthStateResponseInternal": { "AuthStateResponseInternal": {
"type": "object", "type": "object",
"required": [ "required": [
"protocol",
"state" "state"
], ],
"properties": { "properties": {
"protocol": {
"type": "string"
},
"state": { "state": {
"$ref": "#/components/schemas/ApiAuthState" "$ref": "#/components/schemas/ApiAuthState"
} }

View file

@ -30,6 +30,9 @@ pub(crate) async fn command(cli: &crate::Cli) -> Result<()> {
.await .await
.with_context(|| "Checking MySQL key".to_string())?; .with_context(|| "Checking MySQL key".to_string())?;
} }
if !config.store.sso_providers.is_empty() && config.store.external_host.is_none() {
anyhow::bail!("SSO requires the external_host config option");
}
info!("No problems found"); info!("No problems found");
Ok(()) Ok(())
} }

View file

@ -77,7 +77,6 @@ fn check_and_migrate_config(store: &mut serde_yaml::Value) {
} }
} }
#[must_use]
pub fn watch_config<P: AsRef<Path> + Send + 'static>( pub fn watch_config<P: AsRef<Path> + Send + 'static>(
path: P, path: P,
config: Arc<Mutex<WarpgateConfig>>, config: Arc<Mutex<WarpgateConfig>>,