mirror of
https://github.com/warp-tech/warpgate.git
synced 2024-11-10 09:12:56 +08:00
Out-of-band SSO (#245)
This commit is contained in:
parent
fbd8d0dda3
commit
c6885f18c3
28 changed files with 863 additions and 248 deletions
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
|
@ -12,7 +12,6 @@ jobs:
|
||||||
|
|
||||||
- uses: actions-rs/toolchain@v1
|
- uses: actions-rs/toolchain@v1
|
||||||
with:
|
with:
|
||||||
toolchain: nightly
|
|
||||||
target: x86_64-unknown-linux-gnu
|
target: x86_64-unknown-linux-gnu
|
||||||
override: true
|
override: true
|
||||||
|
|
||||||
|
@ -42,7 +41,6 @@ jobs:
|
||||||
uses: actions-rs/cargo@v1
|
uses: actions-rs/cargo@v1
|
||||||
with:
|
with:
|
||||||
command: build
|
command: build
|
||||||
toolchain: nightly
|
|
||||||
use-cross: true
|
use-cross: true
|
||||||
args: --release --target x86_64-unknown-linux-gnu -Ztarget-applies-to-host
|
args: --release --target x86_64-unknown-linux-gnu -Ztarget-applies-to-host
|
||||||
|
|
||||||
|
|
6
Cargo.lock
generated
6
Cargo.lock
generated
|
@ -3143,9 +3143,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "russh"
|
name = "russh"
|
||||||
version = "0.34.0-beta.7"
|
version = "0.34.0-beta.8"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4f3bb72a66e32d52e0e258627d141d5c93b408e050f15033699caa836d064c7e"
|
checksum = "ccd8be93ee0b54a8a6b74c77ecef946185f0acbfb5234ea66666887621381e85"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aes 0.8.1",
|
"aes 0.8.1",
|
||||||
"aes-gcm 0.10.1",
|
"aes-gcm 0.10.1",
|
||||||
|
@ -4620,6 +4620,7 @@ dependencies = [
|
||||||
"bytes 1.2.1",
|
"bytes 1.2.1",
|
||||||
"chrono",
|
"chrono",
|
||||||
"data-encoding",
|
"data-encoding",
|
||||||
|
"futures",
|
||||||
"humantime-serde",
|
"humantime-serde",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
|
@ -4691,6 +4692,7 @@ version = "0.4.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
|
"chrono",
|
||||||
"cookie",
|
"cookie",
|
||||||
"data-encoding",
|
"data-encoding",
|
||||||
"delegate",
|
"delegate",
|
||||||
|
|
1
rust-toolchain
Normal file
1
rust-toolchain
Normal file
|
@ -0,0 +1 @@
|
||||||
|
nightly-2022-08-01
|
|
@ -1,2 +0,0 @@
|
||||||
[toolchain]
|
|
||||||
channel = "nightly-2022-07-22"
|
|
|
@ -13,6 +13,7 @@ chrono = { version = "0.4", features = ["serde"] }
|
||||||
data-encoding = "2.3"
|
data-encoding = "2.3"
|
||||||
humantime-serde = "1.1"
|
humantime-serde = "1.1"
|
||||||
lazy_static = "1.4"
|
lazy_static = "1.4"
|
||||||
|
futures = "0.3"
|
||||||
once_cell = "1.10"
|
once_cell = "1.10"
|
||||||
packet = "0.1"
|
packet = "0.1"
|
||||||
password-hash = "0.4"
|
password-hash = "0.4"
|
||||||
|
|
|
@ -13,6 +13,8 @@ pub enum CredentialKind {
|
||||||
Otp,
|
Otp,
|
||||||
#[serde(rename = "sso")]
|
#[serde(rename = "sso")]
|
||||||
Sso,
|
Sso,
|
||||||
|
#[serde(rename = "web")]
|
||||||
|
WebUserApproval,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
@ -27,6 +29,7 @@ pub enum AuthCredential {
|
||||||
provider: String,
|
provider: String,
|
||||||
email: String,
|
email: String,
|
||||||
},
|
},
|
||||||
|
WebUserApproval,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AuthCredential {
|
impl AuthCredential {
|
||||||
|
@ -36,6 +39,7 @@ impl AuthCredential {
|
||||||
Self::PublicKey { .. } => CredentialKind::PublicKey,
|
Self::PublicKey { .. } => CredentialKind::PublicKey,
|
||||||
Self::Otp { .. } => CredentialKind::Otp,
|
Self::Otp { .. } => CredentialKind::Otp,
|
||||||
Self::Sso { .. } => CredentialKind::Sso,
|
Self::Sso { .. } => CredentialKind::Sso,
|
||||||
|
Self::WebUserApproval => CredentialKind::WebUserApproval,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,10 @@
|
||||||
use std::collections::HashSet;
|
use std::collections::{HashMap, HashSet};
|
||||||
|
|
||||||
use super::{AuthCredential, CredentialKind};
|
use super::{AuthCredential, CredentialKind};
|
||||||
use crate::UserRequireCredentialsPolicy;
|
|
||||||
|
|
||||||
pub enum CredentialPolicyResponse {
|
pub enum CredentialPolicyResponse {
|
||||||
Ok,
|
Ok,
|
||||||
NeedMoreCredentials,
|
Need(HashSet<CredentialKind>),
|
||||||
Need(CredentialKind),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait CredentialPolicy {
|
pub trait CredentialPolicy {
|
||||||
|
@ -17,36 +15,71 @@ pub trait CredentialPolicy {
|
||||||
) -> CredentialPolicyResponse;
|
) -> CredentialPolicyResponse;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CredentialPolicy for UserRequireCredentialsPolicy {
|
pub struct AnySingleCredentialPolicy {
|
||||||
|
pub supported_credential_types: HashSet<CredentialKind>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AllCredentialsPolicy {
|
||||||
|
pub required_credential_types: HashSet<CredentialKind>,
|
||||||
|
pub supported_credential_types: HashSet<CredentialKind>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PerProtocolCredentialPolicy {
|
||||||
|
pub protocols: HashMap<&'static str, Box<dyn CredentialPolicy + Send + Sync>>,
|
||||||
|
pub default: Box<dyn CredentialPolicy + Send + Sync>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CredentialPolicy for AnySingleCredentialPolicy {
|
||||||
fn is_sufficient(
|
fn is_sufficient(
|
||||||
&self,
|
&self,
|
||||||
protocol: &str,
|
_protocol: &str,
|
||||||
valid_credentials: &[AuthCredential],
|
valid_credentials: &[AuthCredential],
|
||||||
) -> CredentialPolicyResponse {
|
) -> CredentialPolicyResponse {
|
||||||
let required_kinds = match protocol {
|
if valid_credentials.is_empty() {
|
||||||
"SSH" => &self.ssh,
|
CredentialPolicyResponse::Need(
|
||||||
"HTTP" => &self.http,
|
self.supported_credential_types
|
||||||
"MySQL" => &self.mysql,
|
.clone()
|
||||||
_ => unreachable!(),
|
.into_iter()
|
||||||
};
|
.collect(),
|
||||||
if let Some(required_kinds) = required_kinds {
|
)
|
||||||
let mut remaining_required_kinds = HashSet::<CredentialKind>::new();
|
|
||||||
remaining_required_kinds.extend(required_kinds);
|
|
||||||
for kind in required_kinds {
|
|
||||||
if valid_credentials.iter().any(|x| x.kind() == *kind) {
|
|
||||||
remaining_required_kinds.remove(kind);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(kind) = remaining_required_kinds.into_iter().next() {
|
|
||||||
CredentialPolicyResponse::Need(kind)
|
|
||||||
} else {
|
|
||||||
CredentialPolicyResponse::Ok
|
|
||||||
}
|
|
||||||
} else if valid_credentials.is_empty() {
|
|
||||||
CredentialPolicyResponse::NeedMoreCredentials
|
|
||||||
} else {
|
} else {
|
||||||
CredentialPolicyResponse::Ok
|
CredentialPolicyResponse::Ok
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl CredentialPolicy for AllCredentialsPolicy {
|
||||||
|
fn is_sufficient(
|
||||||
|
&self,
|
||||||
|
_protocol: &str,
|
||||||
|
valid_credentials: &[AuthCredential],
|
||||||
|
) -> CredentialPolicyResponse {
|
||||||
|
let valid_credential_types: HashSet<CredentialKind> =
|
||||||
|
valid_credentials.iter().map(|x| x.kind()).collect();
|
||||||
|
|
||||||
|
if valid_credential_types.is_superset(&self.required_credential_types) {
|
||||||
|
CredentialPolicyResponse::Ok
|
||||||
|
} else {
|
||||||
|
CredentialPolicyResponse::Need(
|
||||||
|
self.required_credential_types
|
||||||
|
.difference(&valid_credential_types)
|
||||||
|
.cloned()
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CredentialPolicy for PerProtocolCredentialPolicy {
|
||||||
|
fn is_sufficient(
|
||||||
|
&self,
|
||||||
|
protocol: &str,
|
||||||
|
valid_credentials: &[AuthCredential],
|
||||||
|
) -> CredentialPolicyResponse {
|
||||||
|
if let Some(policy) = self.protocols.get(protocol) {
|
||||||
|
policy.is_sufficient(protocol, valid_credentials)
|
||||||
|
} else {
|
||||||
|
self.default.is_sufficient(protocol, valid_credentials)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,71 +1,66 @@
|
||||||
use std::time::{Duration, Instant};
|
use uuid::Uuid;
|
||||||
|
|
||||||
use once_cell::sync::Lazy;
|
|
||||||
use tracing::warn;
|
|
||||||
|
|
||||||
use super::{AuthCredential, CredentialPolicy, CredentialPolicyResponse};
|
use super::{AuthCredential, CredentialPolicy, CredentialPolicyResponse};
|
||||||
use crate::AuthResult;
|
use crate::AuthResult;
|
||||||
|
|
||||||
#[allow(clippy::unwrap_used)]
|
|
||||||
pub static TIMEOUT: Lazy<Duration> = Lazy::new(|| Duration::from_secs(60 * 10));
|
|
||||||
|
|
||||||
pub struct AuthState {
|
pub struct AuthState {
|
||||||
|
id: Uuid,
|
||||||
username: String,
|
username: String,
|
||||||
protocol: String,
|
protocol: String,
|
||||||
policy: Option<Box<dyn CredentialPolicy + Sync + Send>>,
|
force_rejected: bool,
|
||||||
|
policy: Box<dyn CredentialPolicy + Sync + Send>,
|
||||||
valid_credentials: Vec<AuthCredential>,
|
valid_credentials: Vec<AuthCredential>,
|
||||||
started_at: Instant,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AuthState {
|
impl AuthState {
|
||||||
pub fn new(
|
pub(crate) fn new(
|
||||||
|
id: Uuid,
|
||||||
username: String,
|
username: String,
|
||||||
protocol: String,
|
protocol: String,
|
||||||
policy: Option<Box<dyn CredentialPolicy + Sync + Send>>,
|
policy: Box<dyn CredentialPolicy + Sync + Send>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
id,
|
||||||
username,
|
username,
|
||||||
protocol,
|
protocol,
|
||||||
|
force_rejected: false,
|
||||||
policy,
|
policy,
|
||||||
valid_credentials: vec![],
|
valid_credentials: vec![],
|
||||||
started_at: Instant::now(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn id(&self) -> &Uuid {
|
||||||
|
&self.id
|
||||||
|
}
|
||||||
|
|
||||||
pub fn username(&self) -> &str {
|
pub fn username(&self) -> &str {
|
||||||
&self.username
|
&self.username
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn protocol(&self) -> &str {
|
||||||
|
&self.protocol
|
||||||
|
}
|
||||||
|
|
||||||
pub fn add_valid_credential(&mut self, credential: AuthCredential) {
|
pub fn add_valid_credential(&mut self, credential: AuthCredential) {
|
||||||
self.valid_credentials.push(credential);
|
self.valid_credentials.push(credential);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_expired(&self) -> bool {
|
pub fn reject(&mut self) {
|
||||||
self.started_at.elapsed() > *TIMEOUT
|
self.force_rejected = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn verify(&self) -> AuthResult {
|
pub fn verify(&self) -> AuthResult {
|
||||||
if self.valid_credentials.is_empty() {
|
if self.force_rejected {
|
||||||
warn!(
|
|
||||||
username=%self.username,
|
|
||||||
"No matching valid credentials"
|
|
||||||
);
|
|
||||||
return AuthResult::Rejected;
|
return AuthResult::Rejected;
|
||||||
}
|
}
|
||||||
|
match self
|
||||||
if let Some(ref policy) = self.policy {
|
.policy
|
||||||
match policy.is_sufficient(&self.protocol, &self.valid_credentials[..]) {
|
.is_sufficient(&self.protocol, &self.valid_credentials[..])
|
||||||
CredentialPolicyResponse::Ok => {}
|
{
|
||||||
CredentialPolicyResponse::Need(kind) => {
|
CredentialPolicyResponse::Ok => AuthResult::Accepted {
|
||||||
return AuthResult::Need(kind);
|
username: self.username.clone(),
|
||||||
}
|
},
|
||||||
CredentialPolicyResponse::NeedMoreCredentials => {
|
CredentialPolicyResponse::Need(kinds) => AuthResult::Need(kinds),
|
||||||
return AuthResult::Rejected;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
AuthResult::Accepted {
|
|
||||||
username: self.username.clone(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +1,32 @@
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use tokio::sync::Mutex;
|
use once_cell::sync::Lazy;
|
||||||
|
use tokio::sync::{broadcast, Mutex};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use super::AuthState;
|
use super::AuthState;
|
||||||
use crate::{ConfigProvider, WarpgateError};
|
use crate::{AuthResult, ConfigProvider, WarpgateError};
|
||||||
|
|
||||||
|
#[allow(clippy::unwrap_used)]
|
||||||
|
pub static TIMEOUT: Lazy<Duration> = Lazy::new(|| Duration::from_secs(60 * 10));
|
||||||
|
|
||||||
|
struct AuthCompletionSignal {
|
||||||
|
sender: broadcast::Sender<AuthResult>,
|
||||||
|
created_at: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthCompletionSignal {
|
||||||
|
pub fn is_expired(&self) -> bool {
|
||||||
|
self.created_at.elapsed() > *TIMEOUT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct AuthStateStore {
|
pub struct AuthStateStore {
|
||||||
config_provider: Arc<Mutex<dyn ConfigProvider + Send + 'static>>,
|
config_provider: Arc<Mutex<dyn ConfigProvider + Send + 'static>>,
|
||||||
store: HashMap<Uuid, AuthState>,
|
store: HashMap<Uuid, (Arc<Mutex<AuthState>>, Instant)>,
|
||||||
|
completion_signals: HashMap<Uuid, AuthCompletionSignal>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AuthStateStore {
|
impl AuthStateStore {
|
||||||
|
@ -17,47 +34,66 @@ impl AuthStateStore {
|
||||||
Self {
|
Self {
|
||||||
store: HashMap::new(),
|
store: HashMap::new(),
|
||||||
config_provider,
|
config_provider,
|
||||||
|
completion_signals: HashMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn contains_key(&mut self, id: &Uuid) -> bool {
|
pub fn contains_key(&self, id: &Uuid) -> bool {
|
||||||
self.store.contains_key(id)
|
self.store.contains_key(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_mut(&mut self, id: &Uuid) -> Option<&mut AuthState> {
|
pub fn get(&self, id: &Uuid) -> Option<Arc<Mutex<AuthState>>> {
|
||||||
self.store.get_mut(id)
|
self.store.get(id).map(|x| x.0.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create(
|
pub async fn create(
|
||||||
&mut self,
|
&mut self,
|
||||||
username: &str,
|
username: &str,
|
||||||
protocol: &str,
|
protocol: &str,
|
||||||
) -> Result<(Uuid, &mut AuthState), WarpgateError> {
|
) -> Result<(Uuid, Arc<Mutex<AuthState>>), WarpgateError> {
|
||||||
let id = Uuid::new_v4();
|
let id = Uuid::new_v4();
|
||||||
let state = AuthState::new(
|
let Some(policy) = self.config_provider
|
||||||
username.to_string(),
|
.lock()
|
||||||
protocol.to_string(),
|
.await
|
||||||
self.config_provider
|
.get_credential_policy(username)
|
||||||
.lock()
|
.await? else {
|
||||||
.await
|
return Err(WarpgateError::UserNotFound)
|
||||||
.get_credential_policy(username)
|
};
|
||||||
.await?,
|
|
||||||
);
|
let state = AuthState::new(id, username.to_string(), protocol.to_string(), policy);
|
||||||
self.store.insert(id, state);
|
self.store
|
||||||
|
.insert(id, (Arc::new(Mutex::new(state)), Instant::now()));
|
||||||
|
|
||||||
#[allow(clippy::unwrap_used)]
|
#[allow(clippy::unwrap_used)]
|
||||||
Ok((id, self.store.get_mut(&id).unwrap()))
|
Ok((id, self.get(&id).unwrap()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn subscribe(&mut self, id: Uuid) -> broadcast::Receiver<AuthResult> {
|
||||||
|
let signal = self.completion_signals.entry(id).or_insert_with(|| {
|
||||||
|
let (sender, _) = broadcast::channel(1);
|
||||||
|
AuthCompletionSignal {
|
||||||
|
sender,
|
||||||
|
created_at: Instant::now(),
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
signal.sender.subscribe()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn complete(&mut self, id: &Uuid) {
|
||||||
|
let Some((state, _)) = self.store.get(id) else {
|
||||||
|
return
|
||||||
|
};
|
||||||
|
if let Some(sig) = self.completion_signals.remove(id) {
|
||||||
|
let _ = sig.sender.send(state.lock().await.verify());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn vacuum(&mut self) {
|
pub async fn vacuum(&mut self) {
|
||||||
let mut to_remove = vec![];
|
self.store
|
||||||
for (id, state) in self.store.iter() {
|
.retain(|_, (_, started_at)| started_at.elapsed() < *TIMEOUT);
|
||||||
if state.is_expired() {
|
|
||||||
to_remove.push(*id);
|
self.completion_signals
|
||||||
}
|
.retain(|_, signal| !signal.is_expired());
|
||||||
}
|
|
||||||
for id in to_remove {
|
|
||||||
self.store.remove(&id);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,11 +5,12 @@ use std::time::Duration;
|
||||||
|
|
||||||
use poem_openapi::{Enum, Object, Union};
|
use poem_openapi::{Enum, Object, Union};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use url::Url;
|
||||||
use warpgate_sso::SsoProviderConfig;
|
use warpgate_sso::SsoProviderConfig;
|
||||||
|
|
||||||
use crate::auth::CredentialKind;
|
use crate::auth::CredentialKind;
|
||||||
use crate::helpers::otp::OtpSecretKey;
|
use crate::helpers::otp::OtpSecretKey;
|
||||||
use crate::{ListenEndpoint, Secret};
|
use crate::{ListenEndpoint, Secret, WarpgateError};
|
||||||
|
|
||||||
const fn _default_true() -> bool {
|
const fn _default_true() -> bool {
|
||||||
true
|
true
|
||||||
|
@ -426,3 +427,24 @@ pub struct WarpgateConfig {
|
||||||
pub store: WarpgateConfigStore,
|
pub store: WarpgateConfigStore,
|
||||||
pub paths_relative_to: PathBuf,
|
pub paths_relative_to: PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl WarpgateConfig {
|
||||||
|
pub fn construct_external_url(
|
||||||
|
&self,
|
||||||
|
fallback_host: Option<&str>,
|
||||||
|
) -> Result<Url, WarpgateError> {
|
||||||
|
let ext_host = self.store.external_host.as_deref().or(fallback_host);
|
||||||
|
let Some(ext_host) = ext_host else {
|
||||||
|
return Err(WarpgateError::ExternalHostNotSet);
|
||||||
|
};
|
||||||
|
let ext_port = self.store.http.listen.port();
|
||||||
|
|
||||||
|
let mut url = Url::parse(&format!("https://{ext_host}/"))?;
|
||||||
|
|
||||||
|
if ext_port != 443 {
|
||||||
|
let _ = url.set_port(Some(ext_port));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use std::collections::HashSet;
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
@ -11,7 +11,10 @@ use uuid::Uuid;
|
||||||
use warpgate_db_entities::Ticket;
|
use warpgate_db_entities::Ticket;
|
||||||
|
|
||||||
use super::ConfigProvider;
|
use super::ConfigProvider;
|
||||||
use crate::auth::{AuthCredential, CredentialPolicy};
|
use crate::auth::{
|
||||||
|
AllCredentialsPolicy, AnySingleCredentialPolicy, AuthCredential, CredentialKind,
|
||||||
|
CredentialPolicy, PerProtocolCredentialPolicy,
|
||||||
|
};
|
||||||
use crate::helpers::hash::verify_password_hash;
|
use crate::helpers::hash::verify_password_hash;
|
||||||
use crate::helpers::otp::verify_totp;
|
use crate::helpers::otp::verify_totp;
|
||||||
use crate::{Target, User, UserAuthCredential, UserSnapshot, WarpgateConfig, WarpgateError};
|
use crate::{Target, User, UserAuthCredential, UserSnapshot, WarpgateConfig, WarpgateError};
|
||||||
|
@ -78,9 +81,52 @@ impl ConfigProvider for FileConfigProvider {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(user
|
let supported_credential_types: HashSet<CredentialKind> =
|
||||||
.require
|
user.credentials.iter().map(|x| x.kind()).collect();
|
||||||
.map(|r| Box::new(r) as Box<dyn CredentialPolicy + Sync + Send>))
|
let default_policy = Box::new(AnySingleCredentialPolicy {
|
||||||
|
supported_credential_types: supported_credential_types.clone(),
|
||||||
|
}) as Box<dyn CredentialPolicy + Sync + Send>;
|
||||||
|
|
||||||
|
if let Some(req) = user.require {
|
||||||
|
let mut policy = PerProtocolCredentialPolicy {
|
||||||
|
default: default_policy,
|
||||||
|
protocols: HashMap::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(p) = req.http {
|
||||||
|
policy.protocols.insert(
|
||||||
|
"HTTP",
|
||||||
|
Box::new(AllCredentialsPolicy {
|
||||||
|
supported_credential_types: supported_credential_types.clone(),
|
||||||
|
required_credential_types: p.into_iter().collect(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if let Some(p) = req.mysql {
|
||||||
|
policy.protocols.insert(
|
||||||
|
"MySQL",
|
||||||
|
Box::new(AllCredentialsPolicy {
|
||||||
|
supported_credential_types: supported_credential_types.clone(),
|
||||||
|
required_credential_types: p.into_iter().collect(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if let Some(p) = req.ssh {
|
||||||
|
policy.protocols.insert(
|
||||||
|
"SSH",
|
||||||
|
Box::new(AllCredentialsPolicy {
|
||||||
|
supported_credential_types,
|
||||||
|
required_credential_types: p.into_iter().collect(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Some(
|
||||||
|
Box::new(policy) as Box<dyn CredentialPolicy + Sync + Send>
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(Some(default_policy))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn username_for_sso_credential(
|
async fn username_for_sso_credential(
|
||||||
|
@ -193,6 +239,7 @@ impl ConfigProvider for FileConfigProvider {
|
||||||
}
|
}
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
_ => return Err(WarpgateError::InvalidCredentialType),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
mod file;
|
mod file;
|
||||||
|
use std::collections::HashSet;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
@ -12,11 +13,10 @@ use warpgate_db_entities::Ticket;
|
||||||
use crate::auth::{AuthCredential, CredentialKind, CredentialPolicy};
|
use crate::auth::{AuthCredential, CredentialKind, CredentialPolicy};
|
||||||
use crate::{Secret, Target, UserSnapshot, WarpgateError};
|
use crate::{Secret, Target, UserSnapshot, WarpgateError};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum AuthResult {
|
pub enum AuthResult {
|
||||||
Accepted { username: String },
|
Accepted { username: String },
|
||||||
Need(CredentialKind),
|
Need(HashSet<CredentialKind>),
|
||||||
NeedMoreCredentials,
|
|
||||||
Rejected,
|
Rejected,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,8 +9,16 @@ pub enum WarpgateError {
|
||||||
DatabaseError(#[from] sea_orm::DbErr),
|
DatabaseError(#[from] sea_orm::DbErr),
|
||||||
#[error("ticket not found: {0}")]
|
#[error("ticket not found: {0}")]
|
||||||
InvalidTicket(Uuid),
|
InvalidTicket(Uuid),
|
||||||
|
#[error("invalid credential type")]
|
||||||
|
InvalidCredentialType,
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Other(Box<dyn Error + Send + Sync>),
|
Other(Box<dyn Error + Send + Sync>),
|
||||||
|
#[error("user not found")]
|
||||||
|
UserNotFound,
|
||||||
|
#[error("failed to parse url: {0}")]
|
||||||
|
UrlParse(#[from] url::ParseError),
|
||||||
|
#[error("external_url config option is not set")]
|
||||||
|
ExternalHostNotSet,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResponseError for WarpgateError {
|
impl ResponseError for WarpgateError {
|
||||||
|
|
|
@ -7,6 +7,7 @@ version = "0.4.0"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
|
chrono = {version = "0.4", features = ["serde"]}
|
||||||
cookie = "0.16"
|
cookie = "0.16"
|
||||||
data-encoding = "2.3"
|
data-encoding = "2.3"
|
||||||
delegate = "0.6"
|
delegate = "0.6"
|
||||||
|
|
|
@ -3,14 +3,18 @@ use std::sync::Arc;
|
||||||
use poem::session::Session;
|
use poem::session::Session;
|
||||||
use poem::web::Data;
|
use poem::web::Data;
|
||||||
use poem::Request;
|
use poem::Request;
|
||||||
|
use poem_openapi::param::Path;
|
||||||
use poem_openapi::payload::Json;
|
use poem_openapi::payload::Json;
|
||||||
use poem_openapi::{ApiResponse, Enum, Object, OpenApi};
|
use poem_openapi::{ApiResponse, Enum, Object, OpenApi};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tracing::*;
|
use tracing::*;
|
||||||
use warpgate_common::auth::{AuthCredential, CredentialKind};
|
use uuid::Uuid;
|
||||||
|
use warpgate_common::auth::{AuthCredential, AuthState, CredentialKind};
|
||||||
use warpgate_common::{AuthResult, Secret, Services};
|
use warpgate_common::{AuthResult, Secret, Services};
|
||||||
|
|
||||||
use crate::common::{authorize_session, get_auth_state_for_request, SessionExt};
|
use crate::common::{
|
||||||
|
authorize_session, endpoint_auth, get_auth_state_for_request, SessionAuthorization, SessionExt,
|
||||||
|
};
|
||||||
use crate::session::SessionStore;
|
use crate::session::SessionStore;
|
||||||
|
|
||||||
pub struct Api;
|
pub struct Api;
|
||||||
|
@ -33,6 +37,8 @@ enum ApiAuthState {
|
||||||
PasswordNeeded,
|
PasswordNeeded,
|
||||||
OtpNeeded,
|
OtpNeeded,
|
||||||
SsoNeeded,
|
SsoNeeded,
|
||||||
|
WebUserApprovalNeeded,
|
||||||
|
PublicKeyNeeded,
|
||||||
Success,
|
Success,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,6 +64,7 @@ enum LogoutResponse {
|
||||||
|
|
||||||
#[derive(Object)]
|
#[derive(Object)]
|
||||||
struct AuthStateResponseInternal {
|
struct AuthStateResponseInternal {
|
||||||
|
pub protocol: String,
|
||||||
pub state: ApiAuthState,
|
pub state: ApiAuthState,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,17 +72,22 @@ struct AuthStateResponseInternal {
|
||||||
enum AuthStateResponse {
|
enum AuthStateResponse {
|
||||||
#[oai(status = 200)]
|
#[oai(status = 200)]
|
||||||
Ok(Json<AuthStateResponseInternal>),
|
Ok(Json<AuthStateResponseInternal>),
|
||||||
|
#[oai(status = 404)]
|
||||||
|
NotFound,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<AuthResult> for ApiAuthState {
|
impl From<AuthResult> for ApiAuthState {
|
||||||
fn from(state: AuthResult) -> Self {
|
fn from(state: AuthResult) -> Self {
|
||||||
match state {
|
match state {
|
||||||
AuthResult::Rejected => ApiAuthState::Failed,
|
AuthResult::Rejected => ApiAuthState::Failed,
|
||||||
AuthResult::Need(CredentialKind::Password) => ApiAuthState::PasswordNeeded,
|
AuthResult::Need(kinds) => match kinds.iter().next() {
|
||||||
AuthResult::Need(CredentialKind::Otp) => ApiAuthState::OtpNeeded,
|
Some(CredentialKind::Password) => ApiAuthState::PasswordNeeded,
|
||||||
AuthResult::Need(CredentialKind::Sso) => ApiAuthState::SsoNeeded,
|
Some(CredentialKind::Otp) => ApiAuthState::OtpNeeded,
|
||||||
AuthResult::Need(CredentialKind::PublicKey) => ApiAuthState::Failed,
|
Some(CredentialKind::Sso) => ApiAuthState::SsoNeeded,
|
||||||
AuthResult::NeedMoreCredentials => ApiAuthState::Failed,
|
Some(CredentialKind::WebUserApproval) => ApiAuthState::WebUserApprovalNeeded,
|
||||||
|
Some(CredentialKind::PublicKey) => ApiAuthState::PublicKeyNeeded,
|
||||||
|
None => ApiAuthState::Failed,
|
||||||
|
},
|
||||||
AuthResult::Accepted { .. } => ApiAuthState::Success,
|
AuthResult::Accepted { .. } => ApiAuthState::Success,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -92,8 +104,9 @@ impl Api {
|
||||||
body: Json<LoginRequest>,
|
body: Json<LoginRequest>,
|
||||||
) -> poem::Result<LoginResponse> {
|
) -> poem::Result<LoginResponse> {
|
||||||
let mut auth_state_store = services.auth_state_store.lock().await;
|
let mut auth_state_store = services.auth_state_store.lock().await;
|
||||||
let state =
|
let state_arc =
|
||||||
get_auth_state_for_request(&body.username, session, &mut auth_state_store).await?;
|
get_auth_state_for_request(&body.username, session, &mut auth_state_store).await?;
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
|
||||||
let mut cp = services.config_provider.lock().await;
|
let mut cp = services.config_provider.lock().await;
|
||||||
|
|
||||||
|
@ -107,6 +120,7 @@ impl Api {
|
||||||
|
|
||||||
match state.verify() {
|
match state.verify() {
|
||||||
AuthResult::Accepted { username } => {
|
AuthResult::Accepted { username } => {
|
||||||
|
auth_state_store.complete(state.id()).await;
|
||||||
authorize_session(req, username).await?;
|
authorize_session(req, username).await?;
|
||||||
Ok(LoginResponse::Success)
|
Ok(LoginResponse::Success)
|
||||||
}
|
}
|
||||||
|
@ -131,12 +145,14 @@ impl Api {
|
||||||
|
|
||||||
let mut auth_state_store = services.auth_state_store.lock().await;
|
let mut auth_state_store = services.auth_state_store.lock().await;
|
||||||
|
|
||||||
let Some(state) = state_id.and_then(|id| auth_state_store.get_mut(&id.0)) else {
|
let Some(state_arc) = state_id.and_then(|id| auth_state_store.get(&id.0)) else {
|
||||||
return Ok(LoginResponse::Failure(Json(LoginFailureResponse {
|
return Ok(LoginResponse::Failure(Json(LoginFailureResponse {
|
||||||
state: ApiAuthState::NotStarted,
|
state: ApiAuthState::NotStarted,
|
||||||
})))
|
})))
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
|
||||||
let mut cp = services.config_provider.lock().await;
|
let mut cp = services.config_provider.lock().await;
|
||||||
|
|
||||||
let otp_cred = AuthCredential::Otp(body.otp.clone().into());
|
let otp_cred = AuthCredential::Otp(body.otp.clone().into());
|
||||||
|
@ -146,6 +162,7 @@ impl Api {
|
||||||
|
|
||||||
match state.verify() {
|
match state.verify() {
|
||||||
AuthResult::Accepted { username } => {
|
AuthResult::Accepted { username } => {
|
||||||
|
auth_state_store.complete(state.id()).await;
|
||||||
authorize_session(req, username).await?;
|
authorize_session(req, username).await?;
|
||||||
Ok(LoginResponse::Success)
|
Ok(LoginResponse::Success)
|
||||||
}
|
}
|
||||||
|
@ -155,27 +172,6 @@ impl Api {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[oai(path = "/auth/state", method = "get", operation_id = "getAuthState")]
|
|
||||||
async fn api_auth_state(
|
|
||||||
&self,
|
|
||||||
session: &Session,
|
|
||||||
services: Data<&Services>,
|
|
||||||
) -> poem::Result<AuthStateResponse> {
|
|
||||||
let state_id = session.get_auth_state_id();
|
|
||||||
|
|
||||||
let mut auth_state_store = services.auth_state_store.lock().await;
|
|
||||||
|
|
||||||
let Some(state) = state_id.and_then(|id| auth_state_store.get_mut(&id.0)) else {
|
|
||||||
return Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal {
|
|
||||||
state: ApiAuthState::NotStarted,
|
|
||||||
})));
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal {
|
|
||||||
state: state.verify().into(),
|
|
||||||
})))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[oai(path = "/auth/logout", method = "post", operation_id = "logout")]
|
#[oai(path = "/auth/logout", method = "post", operation_id = "logout")]
|
||||||
async fn api_auth_logout(
|
async fn api_auth_logout(
|
||||||
&self,
|
&self,
|
||||||
|
@ -187,4 +183,129 @@ impl Api {
|
||||||
info!("Logged out");
|
info!("Logged out");
|
||||||
Ok(LogoutResponse::Success)
|
Ok(LogoutResponse::Success)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[oai(
|
||||||
|
path = "/auth/state",
|
||||||
|
method = "get",
|
||||||
|
operation_id = "getDefaultAuthState"
|
||||||
|
)]
|
||||||
|
async fn api_default_auth_state(
|
||||||
|
&self,
|
||||||
|
session: &Session,
|
||||||
|
services: Data<&Services>,
|
||||||
|
) -> poem::Result<AuthStateResponse> {
|
||||||
|
let Some(state_id) = session.get_auth_state_id() else {
|
||||||
|
return Ok(AuthStateResponse::NotFound)
|
||||||
|
};
|
||||||
|
let store = services.auth_state_store.lock().await;
|
||||||
|
let Some(state_arc) = store.get(&state_id.0) else {
|
||||||
|
return Ok(AuthStateResponse::NotFound);
|
||||||
|
};
|
||||||
|
serialize_auth_state_inner(state_arc).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[oai(
|
||||||
|
path = "/auth/state/:id",
|
||||||
|
method = "get",
|
||||||
|
operation_id = "get_auth_state",
|
||||||
|
transform = "endpoint_auth"
|
||||||
|
)]
|
||||||
|
async fn api_auth_state(
|
||||||
|
&self,
|
||||||
|
services: Data<&Services>,
|
||||||
|
auth: Option<Data<&SessionAuthorization>>,
|
||||||
|
id: Path<Uuid>,
|
||||||
|
) -> poem::Result<AuthStateResponse> {
|
||||||
|
let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else {
|
||||||
|
return Ok(AuthStateResponse::NotFound);
|
||||||
|
};
|
||||||
|
serialize_auth_state_inner(state_arc).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[oai(
|
||||||
|
path = "/auth/state/:id/approve",
|
||||||
|
method = "post",
|
||||||
|
operation_id = "approve_auth",
|
||||||
|
transform = "endpoint_auth"
|
||||||
|
)]
|
||||||
|
async fn api_approve_auth(
|
||||||
|
&self,
|
||||||
|
services: Data<&Services>,
|
||||||
|
auth: Option<Data<&SessionAuthorization>>,
|
||||||
|
id: Path<Uuid>,
|
||||||
|
) -> poem::Result<AuthStateResponse> {
|
||||||
|
let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else {
|
||||||
|
return Ok(AuthStateResponse::NotFound);
|
||||||
|
};
|
||||||
|
|
||||||
|
let auth_result = {
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
state.add_valid_credential(AuthCredential::WebUserApproval);
|
||||||
|
state.verify()
|
||||||
|
};
|
||||||
|
|
||||||
|
if let AuthResult::Accepted { .. } = auth_result {
|
||||||
|
services.auth_state_store.lock().await.complete(&*id).await;
|
||||||
|
}
|
||||||
|
serialize_auth_state_inner(state_arc).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[oai(
|
||||||
|
path = "/auth/state/:id/reject",
|
||||||
|
method = "post",
|
||||||
|
operation_id = "reject_auth",
|
||||||
|
transform = "endpoint_auth"
|
||||||
|
)]
|
||||||
|
async fn api_reject_auth(
|
||||||
|
&self,
|
||||||
|
services: Data<&Services>,
|
||||||
|
auth: Option<Data<&SessionAuthorization>>,
|
||||||
|
id: Path<Uuid>,
|
||||||
|
) -> poem::Result<AuthStateResponse> {
|
||||||
|
let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else {
|
||||||
|
return Ok(AuthStateResponse::NotFound);
|
||||||
|
};
|
||||||
|
state_arc.lock().await.reject();
|
||||||
|
services.auth_state_store.lock().await.complete(&*id).await;
|
||||||
|
serialize_auth_state_inner(state_arc).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_auth_state(
|
||||||
|
id: &Uuid,
|
||||||
|
services: &Services,
|
||||||
|
auth: Option<&SessionAuthorization>,
|
||||||
|
) -> Option<Arc<Mutex<AuthState>>> {
|
||||||
|
let store = services.auth_state_store.lock().await;
|
||||||
|
|
||||||
|
let Some(auth) = auth else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
let SessionAuthorization::User(username) = auth else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(state_arc) = store.get(&*id) else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
let state = state_arc.lock().await;
|
||||||
|
if state.username() != username {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(state_arc)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn serialize_auth_state_inner(
|
||||||
|
state_arc: Arc<Mutex<AuthState>>,
|
||||||
|
) -> poem::Result<AuthStateResponse> {
|
||||||
|
let state = state_arc.lock().await;
|
||||||
|
Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal {
|
||||||
|
protocol: state.protocol().to_string(),
|
||||||
|
state: state.verify().into(),
|
||||||
|
})))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
use poem::session::Session;
|
use poem::session::Session;
|
||||||
use poem::web::Data;
|
use poem::web::Data;
|
||||||
use poem::Request;
|
use poem::Request;
|
||||||
use poem_openapi::param::Path;
|
use poem_openapi::param::{Path, Query};
|
||||||
use poem_openapi::payload::Json;
|
use poem_openapi::payload::Json;
|
||||||
use poem_openapi::{ApiResponse, Object, OpenApi};
|
use poem_openapi::{ApiResponse, Object, OpenApi};
|
||||||
use reqwest::Url;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use warpgate_common::Services;
|
use warpgate_common::Services;
|
||||||
use warpgate_sso::{SsoClient, SsoLoginRequest};
|
use warpgate_sso::{SsoClient, SsoLoginRequest};
|
||||||
|
@ -31,6 +30,7 @@ pub static SSO_CONTEXT_SESSION_KEY: &str = "sso_request";
|
||||||
pub struct SsoContext {
|
pub struct SsoContext {
|
||||||
pub provider: String,
|
pub provider: String,
|
||||||
pub request: SsoLoginRequest,
|
pub request: SsoLoginRequest,
|
||||||
|
pub next_url: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[OpenApi]
|
#[OpenApi]
|
||||||
|
@ -46,31 +46,14 @@ impl Api {
|
||||||
session: &Session,
|
session: &Session,
|
||||||
services: Data<&Services>,
|
services: Data<&Services>,
|
||||||
name: Path<String>,
|
name: Path<String>,
|
||||||
|
next: Query<Option<String>>,
|
||||||
) -> poem::Result<StartSsoResponse> {
|
) -> poem::Result<StartSsoResponse> {
|
||||||
let config = services.config.lock().await;
|
let config = services.config.lock().await;
|
||||||
|
|
||||||
let name = name.0;
|
let name = name.0;
|
||||||
let ext_host = config
|
|
||||||
.store
|
|
||||||
.external_host
|
|
||||||
.as_deref()
|
|
||||||
.or_else(|| req.original_uri().host());
|
|
||||||
let Some(ext_host) = ext_host else {
|
|
||||||
return Err(poem::Error::from_string("external_host config option is required for SSO", http::status::StatusCode::INTERNAL_SERVER_ERROR));
|
|
||||||
};
|
|
||||||
let ext_port = config.store.http.listen.port();
|
|
||||||
|
|
||||||
let mut return_url = Url::parse(&format!("https://{ext_host}/@warpgate/api/sso/return"))
|
let mut return_url = config.construct_external_url(req.original_uri().host())?;
|
||||||
.map_err(|e| {
|
return_url.set_path("@warpgate/api/sso/return");
|
||||||
poem::Error::from_string(
|
|
||||||
format!("failed to construct the return URL: {e}"),
|
|
||||||
http::status::StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
if ext_port != 443 {
|
|
||||||
let _ = return_url.set_port(Some(ext_port));
|
|
||||||
}
|
|
||||||
|
|
||||||
let Some(provider_config) = config.store.sso_providers.iter().find(|p| p.name == *name) else {
|
let Some(provider_config) = config.store.sso_providers.iter().find(|p| p.name == *name) else {
|
||||||
return Ok(StartSsoResponse::NotFound);
|
return Ok(StartSsoResponse::NotFound);
|
||||||
|
@ -84,10 +67,14 @@ impl Api {
|
||||||
.map_err(poem::error::InternalServerError)?;
|
.map_err(poem::error::InternalServerError)?;
|
||||||
|
|
||||||
let url = sso_req.auth_url().to_string();
|
let url = sso_req.auth_url().to_string();
|
||||||
session.set(SSO_CONTEXT_SESSION_KEY, SsoContext {
|
session.set(
|
||||||
provider: name,
|
SSO_CONTEXT_SESSION_KEY,
|
||||||
request: sso_req,
|
SsoContext {
|
||||||
});
|
provider: name,
|
||||||
|
request: sso_req,
|
||||||
|
next_url: next.0.clone(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
Ok(StartSsoResponse::Ok(Json(StartSsoResponseParams { url })))
|
Ok(StartSsoResponse::Ok(Json(StartSsoResponseParams { url })))
|
||||||
}
|
}
|
||||||
|
|
|
@ -120,8 +120,9 @@ impl Api {
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut auth_state_store = services.auth_state_store.lock().await;
|
let mut auth_state_store = services.auth_state_store.lock().await;
|
||||||
let state = get_auth_state_for_request(&username, session, &mut auth_state_store).await?;
|
let state_arc = get_auth_state_for_request(&username, session, &mut auth_state_store).await?;
|
||||||
|
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
let mut cp = services.config_provider.lock().await;
|
let mut cp = services.config_provider.lock().await;
|
||||||
|
|
||||||
if cp.validate_credential(&username, &cred).await? {
|
if cp.validate_credential(&username, &cred).await? {
|
||||||
|
@ -130,11 +131,15 @@ impl Api {
|
||||||
|
|
||||||
match state.verify() {
|
match state.verify() {
|
||||||
AuthResult::Accepted { username } => {
|
AuthResult::Accepted { username } => {
|
||||||
|
auth_state_store.complete(state.id()).await;
|
||||||
authorize_session(req, username).await?;
|
authorize_session(req, username).await?;
|
||||||
}
|
}
|
||||||
_ => ()
|
_ => (),
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Response::new(ReturnToSsoResponse::Ok).header("Location", "/@warpgate"))
|
Ok(Response::new(ReturnToSsoResponse::Ok).header(
|
||||||
|
"Location",
|
||||||
|
context.next_url.as_deref().unwrap_or("/@warpgate#/login"),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -170,18 +170,18 @@ pub fn gateway_redirect(req: &Request) -> Response {
|
||||||
.unwrap_or("".into());
|
.unwrap_or("".into());
|
||||||
|
|
||||||
let path = format!(
|
let path = format!(
|
||||||
"/@warpgate?next={}",
|
"/@warpgate#/login?next={}",
|
||||||
utf8_percent_encode(&path, NON_ALPHANUMERIC),
|
utf8_percent_encode(&path, NON_ALPHANUMERIC),
|
||||||
);
|
);
|
||||||
|
|
||||||
Redirect::temporary(path).into_response()
|
Redirect::temporary(path).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_auth_state_for_request<'a>(
|
pub async fn get_auth_state_for_request(
|
||||||
username: &str,
|
username: &str,
|
||||||
session: &Session,
|
session: &Session,
|
||||||
store: &'a mut AuthStateStore,
|
store: &mut AuthStateStore,
|
||||||
) -> Result<&'a mut AuthState, WarpgateError> {
|
) -> Result<Arc<Mutex<AuthState>>, WarpgateError> {
|
||||||
match session.get_auth_state_id() {
|
match session.get_auth_state_id() {
|
||||||
Some(id) => {
|
Some(id) => {
|
||||||
if !store.contains_key(&id.0) {
|
if !store.contains_key(&id.0) {
|
||||||
|
@ -192,7 +192,7 @@ pub async fn get_auth_state_for_request<'a>(
|
||||||
};
|
};
|
||||||
|
|
||||||
match session.get_auth_state_id() {
|
match session.get_auth_state_id() {
|
||||||
Some(id) => Ok(store.get_mut(&id.0).unwrap()),
|
Some(id) => Ok(store.get(&id.0).unwrap()),
|
||||||
None => {
|
None => {
|
||||||
let (id, state) = store
|
let (id, state) = store
|
||||||
.create(&username, crate::common::PROTOCOL_NAME)
|
.create(&username, crate::common::PROTOCOL_NAME)
|
||||||
|
|
|
@ -7,7 +7,7 @@ use tokio::net::TcpStream;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tracing::*;
|
use tracing::*;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState};
|
use warpgate_common::auth::{AuthCredential, AuthSelector};
|
||||||
use warpgate_common::helpers::rng::get_crypto_rng;
|
use warpgate_common::helpers::rng::get_crypto_rng;
|
||||||
use warpgate_common::{
|
use warpgate_common::{
|
||||||
authorize_ticket, AuthResult, Secret, Services, TargetMySqlOptions, TargetOptions,
|
authorize_ticket, AuthResult, Secret, Services, TargetMySqlOptions, TargetOptions,
|
||||||
|
@ -180,15 +180,20 @@ impl MySqlSession {
|
||||||
username,
|
username,
|
||||||
target_name,
|
target_name,
|
||||||
} => {
|
} => {
|
||||||
let user_auth_result = {
|
let state_arc = self
|
||||||
let mut cp = self.services.config_provider.lock().await;
|
.services
|
||||||
|
.auth_state_store
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.create(&username, crate::common::PROTOCOL_NAME)
|
||||||
|
.await?
|
||||||
|
.1;
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
|
||||||
|
let user_auth_result = {
|
||||||
let credential = AuthCredential::Password(password);
|
let credential = AuthCredential::Password(password);
|
||||||
let mut state = AuthState::new(
|
|
||||||
username.clone(),
|
let mut cp = self.services.config_provider.lock().await;
|
||||||
crate::common::PROTOCOL_NAME.to_string(),
|
|
||||||
cp.get_credential_policy(&username).await?,
|
|
||||||
);
|
|
||||||
if cp.validate_credential(&username, &credential).await? {
|
if cp.validate_credential(&username, &credential).await? {
|
||||||
state.add_valid_credential(credential);
|
state.add_valid_credential(credential);
|
||||||
}
|
}
|
||||||
|
@ -198,6 +203,12 @@ impl MySqlSession {
|
||||||
|
|
||||||
match user_auth_result {
|
match user_auth_result {
|
||||||
AuthResult::Accepted { username } => {
|
AuthResult::Accepted { username } => {
|
||||||
|
self.services
|
||||||
|
.auth_state_store
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.complete(state.id())
|
||||||
|
.await;
|
||||||
let target_auth_result = {
|
let target_auth_result = {
|
||||||
self.services
|
self.services
|
||||||
.config_provider
|
.config_provider
|
||||||
|
@ -216,9 +227,7 @@ impl MySqlSession {
|
||||||
}
|
}
|
||||||
self.run_authorized(handshake, username, target_name).await
|
self.run_authorized(handshake, username, target_name).await
|
||||||
}
|
}
|
||||||
AuthResult::Rejected
|
AuthResult::Rejected | AuthResult::Need(_) => fail(&mut self).await, // TODO SSO
|
||||||
| AuthResult::Need(_)
|
|
||||||
| AuthResult::NeedMoreCredentials => fail(&mut self).await, // TODO SSO
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AuthSelector::Ticket { secret } => {
|
AuthSelector::Ticket { secret } => {
|
||||||
|
|
|
@ -12,7 +12,7 @@ bimap = "0.6"
|
||||||
bytes = "1.2"
|
bytes = "1.2"
|
||||||
dialoguer = "0.10"
|
dialoguer = "0.10"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
russh = {version = "0.34.0-beta.7", features = ["openssl"]}
|
russh = {version = "0.34.0-beta.8", features = ["openssl"]}
|
||||||
russh-keys = {version = "0.22.0-beta.4", features = ["openssl"]}
|
russh-keys = {version = "0.22.0-beta.4", features = ["openssl"]}
|
||||||
sea-orm = {version = "^0.9", features = ["runtime-tokio-native-tls"], default-features = false}
|
sea-orm = {version = "^0.9", features = ["runtime-tokio-native-tls"], default-features = false}
|
||||||
thiserror = "1.0"
|
thiserror = "1.0"
|
||||||
|
|
|
@ -23,6 +23,7 @@ pub async fn run_server(services: Services, address: SocketAddr) -> Result<()> {
|
||||||
let config = services.config.lock().await;
|
let config = services.config.lock().await;
|
||||||
russh::server::Config {
|
russh::server::Config {
|
||||||
auth_rejection_time: std::time::Duration::from_secs(1),
|
auth_rejection_time: std::time::Duration::from_secs(1),
|
||||||
|
connection_timeout: Some(std::time::Duration::from_secs(300)),
|
||||||
methods: MethodSet::PUBLICKEY | MethodSet::PASSWORD | MethodSet::KEYBOARD_INTERACTIVE,
|
methods: MethodSet::PUBLICKEY | MethodSet::PASSWORD | MethodSet::KEYBOARD_INTERACTIVE,
|
||||||
keys: load_host_keys(&config)?,
|
keys: load_host_keys(&config)?,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
use std::collections::hash_map::Entry::Vacant;
|
use std::collections::hash_map::Entry::Vacant;
|
||||||
use std::collections::HashMap;
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::net::{Ipv4Addr, SocketAddr};
|
use std::net::{Ipv4Addr, SocketAddr};
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -10,11 +10,11 @@ use anyhow::{Context, Result};
|
||||||
use bimap::BiMap;
|
use bimap::BiMap;
|
||||||
use bytes::{Bytes, BytesMut};
|
use bytes::{Bytes, BytesMut};
|
||||||
use russh::server::Session;
|
use russh::server::Session;
|
||||||
use russh::{CryptoVec, Sig};
|
use russh::{CryptoVec, MethodSet, Sig};
|
||||||
use russh_keys::key::PublicKey;
|
use russh_keys::key::PublicKey;
|
||||||
use russh_keys::PublicKeyBase64;
|
use russh_keys::PublicKeyBase64;
|
||||||
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
|
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
|
||||||
use tokio::sync::{oneshot, Mutex};
|
use tokio::sync::{broadcast, oneshot, Mutex};
|
||||||
use tracing::*;
|
use tracing::*;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState, CredentialKind};
|
use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState, CredentialKind};
|
||||||
|
@ -53,6 +53,12 @@ enum Event {
|
||||||
Client(RCEvent),
|
Client(RCEvent),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum KeyboardInteractiveState {
|
||||||
|
None,
|
||||||
|
OtpRequested,
|
||||||
|
WebAuthRequested(broadcast::Receiver<AuthResult>),
|
||||||
|
}
|
||||||
|
|
||||||
pub struct ServerSession {
|
pub struct ServerSession {
|
||||||
pub id: SessionId,
|
pub id: SessionId,
|
||||||
username: Option<String>,
|
username: Option<String>,
|
||||||
|
@ -74,7 +80,8 @@ pub struct ServerSession {
|
||||||
hub: EventHub<Event>,
|
hub: EventHub<Event>,
|
||||||
event_sender: EventSender<Event>,
|
event_sender: EventSender<Event>,
|
||||||
service_output: ServiceOutput,
|
service_output: ServiceOutput,
|
||||||
auth_state: Option<AuthState>,
|
auth_state: Option<Arc<Mutex<AuthState>>>,
|
||||||
|
keyboard_interactive_state: KeyboardInteractiveState,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn session_debug_tag(id: &SessionId, remote_address: &SocketAddr) -> String {
|
fn session_debug_tag(id: &SessionId, remote_address: &SocketAddr) -> String {
|
||||||
|
@ -142,6 +149,7 @@ impl ServerSession {
|
||||||
so_tx.send(BytesMut::from(data).freeze()).context("x")
|
so_tx.send(BytesMut::from(data).freeze()).context("x")
|
||||||
})),
|
})),
|
||||||
auth_state: None,
|
auth_state: None,
|
||||||
|
keyboard_interactive_state: KeyboardInteractiveState::None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let this = Arc::new(Mutex::new(this));
|
let this = Arc::new(Mutex::new(this));
|
||||||
|
@ -217,18 +225,23 @@ impl ServerSession {
|
||||||
Ok(this)
|
Ok(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_auth_state(&mut self, username: &str) -> Result<&mut AuthState> {
|
async fn get_auth_state(&mut self, username: &str) -> Result<Arc<Mutex<AuthState>>> {
|
||||||
#[allow(clippy::unwrap_used)]
|
#[allow(clippy::unwrap_used)]
|
||||||
if self.auth_state.is_none() || self.auth_state.as_ref().unwrap().username() != username {
|
if self.auth_state.is_none()
|
||||||
let mut cp = self.services.config_provider.lock().await;
|
|| self.auth_state.as_ref().unwrap().lock().await.username() != username
|
||||||
self.auth_state = Some(AuthState::new(
|
{
|
||||||
username.to_string(),
|
let state = self
|
||||||
crate::PROTOCOL_NAME.to_string(),
|
.services
|
||||||
cp.get_credential_policy(username).await?,
|
.auth_state_store
|
||||||
));
|
.lock()
|
||||||
|
.await
|
||||||
|
.create(username, crate::PROTOCOL_NAME)
|
||||||
|
.await?
|
||||||
|
.1;
|
||||||
|
self.auth_state = Some(state);
|
||||||
}
|
}
|
||||||
#[allow(clippy::unwrap_used)]
|
#[allow(clippy::unwrap_used)]
|
||||||
Ok(self.auth_state.as_mut().unwrap())
|
Ok(self.auth_state.as_ref().map(Clone::clone).unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn make_logging_span(&self) -> tracing::Span {
|
pub fn make_logging_span(&self) -> tracing::Span {
|
||||||
|
@ -939,13 +952,17 @@ impl ServerSession {
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
||||||
Ok(AuthResult::Rejected) => russh::server::Auth::Reject,
|
Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
|
||||||
Ok(AuthResult::Need(_) | AuthResult::NeedMoreCredentials) => {
|
proceed_with_methods: Some(MethodSet::all()),
|
||||||
russh::server::Auth::Reject
|
},
|
||||||
}
|
Ok(AuthResult::Need(kinds)) => russh::server::Auth::Reject {
|
||||||
|
proceed_with_methods: Some(self.get_remaining_auth_methods(kinds)),
|
||||||
|
},
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
error!(?error, "Failed to verify credentials");
|
error!(?error, "Failed to verify credentials");
|
||||||
russh::server::Auth::Reject
|
russh::server::Auth::Reject {
|
||||||
|
proceed_with_methods: None,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -963,13 +980,17 @@ impl ServerSession {
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
||||||
Ok(AuthResult::Rejected) => russh::server::Auth::Reject,
|
Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
|
||||||
Ok(AuthResult::Need(_) | AuthResult::NeedMoreCredentials) => {
|
proceed_with_methods: None,
|
||||||
russh::server::Auth::Reject
|
},
|
||||||
}
|
Ok(AuthResult::Need(_)) => russh::server::Auth::Reject {
|
||||||
|
proceed_with_methods: None,
|
||||||
|
},
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
error!(?error, "Failed to verify credentials");
|
error!(?error, "Failed to verify credentials");
|
||||||
russh::server::Auth::Reject
|
russh::server::Auth::Reject {
|
||||||
|
proceed_with_methods: None,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -982,27 +1003,111 @@ impl ServerSession {
|
||||||
let selector: AuthSelector = ssh_username.expose_secret().into();
|
let selector: AuthSelector = ssh_username.expose_secret().into();
|
||||||
info!("Keyboard-interactive auth as {:?}", selector);
|
info!("Keyboard-interactive auth as {:?}", selector);
|
||||||
|
|
||||||
let cred = response.map(AuthCredential::Otp);
|
let cred;
|
||||||
|
match &mut self.keyboard_interactive_state {
|
||||||
|
KeyboardInteractiveState::None => {
|
||||||
|
cred = None;
|
||||||
|
}
|
||||||
|
KeyboardInteractiveState::OtpRequested => {
|
||||||
|
cred = response.map(AuthCredential::Otp);
|
||||||
|
}
|
||||||
|
KeyboardInteractiveState::WebAuthRequested(event) => {
|
||||||
|
cred = None;
|
||||||
|
let _ = event.recv().await;
|
||||||
|
// the auth state has been updated by now
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.keyboard_interactive_state = KeyboardInteractiveState::None;
|
||||||
|
|
||||||
match self.try_auth(&selector, cred).await {
|
match self.try_auth(&selector, cred).await {
|
||||||
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
||||||
Ok(
|
Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
|
||||||
AuthResult::Rejected
|
proceed_with_methods: None,
|
||||||
| AuthResult::NeedMoreCredentials
|
|
||||||
| AuthResult::Need(CredentialKind::Otp),
|
|
||||||
) => russh::server::Auth::Partial {
|
|
||||||
name: Cow::Borrowed("Two-factor authentication"),
|
|
||||||
instructions: Cow::Borrowed(""),
|
|
||||||
prompts: Cow::Owned(vec![(Cow::Borrowed("One-time password: "), true)]),
|
|
||||||
},
|
},
|
||||||
Ok(AuthResult::Need(_)) => russh::server::Auth::Reject, // TODO SSO
|
Ok(AuthResult::Need(kinds)) => {
|
||||||
|
if kinds.contains(&CredentialKind::Otp) {
|
||||||
|
self.keyboard_interactive_state = KeyboardInteractiveState::OtpRequested;
|
||||||
|
russh::server::Auth::Partial {
|
||||||
|
name: Cow::Borrowed("Two-factor authentication"),
|
||||||
|
instructions: Cow::Borrowed(""),
|
||||||
|
prompts: Cow::Owned(vec![(Cow::Borrowed("One-time password: "), true)]),
|
||||||
|
}
|
||||||
|
} else if kinds.contains(&CredentialKind::WebUserApproval) {
|
||||||
|
let Some(auth_state) = self.auth_state.as_ref() else {
|
||||||
|
return russh::server::Auth::Reject { proceed_with_methods: None};
|
||||||
|
};
|
||||||
|
let auth_state_id = *auth_state.lock().await.id();
|
||||||
|
let event = self
|
||||||
|
.services
|
||||||
|
.auth_state_store
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.subscribe(auth_state_id);
|
||||||
|
self.keyboard_interactive_state =
|
||||||
|
KeyboardInteractiveState::WebAuthRequested(event);
|
||||||
|
|
||||||
|
let mut login_url = match self
|
||||||
|
.services
|
||||||
|
.config
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.construct_external_url(None)
|
||||||
|
{
|
||||||
|
Ok(url) => url,
|
||||||
|
Err(error) => {
|
||||||
|
error!(?error, "Failed to construct external URL");
|
||||||
|
return russh::server::Auth::Reject {
|
||||||
|
proceed_with_methods: None,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
login_url.set_path("@warpgate");
|
||||||
|
login_url
|
||||||
|
.set_fragment(Some(&format!("/login?next=%2Flogin%2F{auth_state_id}")));
|
||||||
|
|
||||||
|
russh::server::Auth::Partial {
|
||||||
|
name: Cow::Owned(format!(
|
||||||
|
concat!(
|
||||||
|
"----------------------------------------------------------------\n",
|
||||||
|
"Warpgate authentication: please open {} in your browser\n",
|
||||||
|
"----------------------------------------------------------------\n"
|
||||||
|
),
|
||||||
|
login_url
|
||||||
|
)),
|
||||||
|
instructions: Cow::Borrowed(""),
|
||||||
|
prompts: Cow::Owned(vec![(Cow::Borrowed("Press Enter when done: "), true)]),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
russh::server::Auth::Reject {
|
||||||
|
proceed_with_methods: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
error!(?error, "Failed to verify credentials");
|
error!(?error, "Failed to verify credentials");
|
||||||
russh::server::Auth::Reject
|
russh::server::Auth::Reject {
|
||||||
|
proceed_with_methods: None,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_remaining_auth_methods(&self, kinds: HashSet<CredentialKind>) -> MethodSet {
|
||||||
|
let mut m = MethodSet::empty();
|
||||||
|
for kind in kinds {
|
||||||
|
match kind {
|
||||||
|
CredentialKind::Password => m.insert(MethodSet::PASSWORD),
|
||||||
|
CredentialKind::Otp => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
|
||||||
|
CredentialKind::WebUserApproval => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
|
||||||
|
CredentialKind::PublicKey => m.insert(MethodSet::PUBLICKEY),
|
||||||
|
CredentialKind::Sso => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m
|
||||||
|
}
|
||||||
|
|
||||||
async fn try_auth(
|
async fn try_auth(
|
||||||
&mut self,
|
&mut self,
|
||||||
selector: &AuthSelector,
|
selector: &AuthSelector,
|
||||||
|
@ -1014,7 +1119,9 @@ impl ServerSession {
|
||||||
target_name,
|
target_name,
|
||||||
} => {
|
} => {
|
||||||
let cp = self.services.config_provider.clone();
|
let cp = self.services.config_provider.clone();
|
||||||
let state = self.get_auth_state(username).await?;
|
|
||||||
|
let state_arc = self.get_auth_state(username).await?;
|
||||||
|
let mut state = state_arc.lock().await;
|
||||||
|
|
||||||
if let Some(credential) = credential {
|
if let Some(credential) = credential {
|
||||||
if cp
|
if cp
|
||||||
|
@ -1031,6 +1138,12 @@ impl ServerSession {
|
||||||
|
|
||||||
match user_auth_result {
|
match user_auth_result {
|
||||||
AuthResult::Accepted { username } => {
|
AuthResult::Accepted { username } => {
|
||||||
|
self.services
|
||||||
|
.auth_state_store
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.complete(state.id())
|
||||||
|
.await;
|
||||||
let target_auth_result = {
|
let target_auth_result = {
|
||||||
self.services
|
self.services
|
||||||
.config_provider
|
.config_provider
|
||||||
|
|
|
@ -2,22 +2,25 @@
|
||||||
import { faSignOut } from '@fortawesome/free-solid-svg-icons'
|
import { faSignOut } from '@fortawesome/free-solid-svg-icons'
|
||||||
import { Alert, Spinner } from 'sveltestrap'
|
import { Alert, Spinner } from 'sveltestrap'
|
||||||
import Fa from 'svelte-fa'
|
import Fa from 'svelte-fa'
|
||||||
|
import Router, { push } from 'svelte-spa-router'
|
||||||
|
import { wrap } from 'svelte-spa-router/wrap'
|
||||||
|
import { get } from 'svelte/store'
|
||||||
import { api } from 'gateway/lib/api'
|
import { api } from 'gateway/lib/api'
|
||||||
import { reloadServerInfo, serverInfo } from 'gateway/lib/store'
|
import { reloadServerInfo, serverInfo } from 'gateway/lib/store'
|
||||||
import ThemeSwitcher from 'common/ThemeSwitcher.svelte'
|
import ThemeSwitcher from 'common/ThemeSwitcher.svelte'
|
||||||
import Login from './Login.svelte'
|
|
||||||
import TargetList from './TargetList.svelte'
|
|
||||||
import Logo from 'common/Logo.svelte'
|
import Logo from 'common/Logo.svelte'
|
||||||
|
|
||||||
let redirecting = false
|
let redirecting = false
|
||||||
|
let serverInfoPromise = reloadServerInfo()
|
||||||
|
|
||||||
async function init () {
|
async function init () {
|
||||||
await reloadServerInfo()
|
await serverInfoPromise
|
||||||
}
|
}
|
||||||
|
|
||||||
async function logout () {
|
async function logout () {
|
||||||
await api.logout()
|
await api.logout()
|
||||||
await reloadServerInfo()
|
await reloadServerInfo()
|
||||||
|
push('/login')
|
||||||
}
|
}
|
||||||
|
|
||||||
function onPageResume () {
|
function onPageResume () {
|
||||||
|
@ -25,6 +28,36 @@ function onPageResume () {
|
||||||
init()
|
init()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function requireLogin (detail) {
|
||||||
|
await serverInfoPromise
|
||||||
|
if (!get(serverInfo)?.username) {
|
||||||
|
let url = detail.location
|
||||||
|
if (detail.querystring) {
|
||||||
|
url += '?' + detail.querystring
|
||||||
|
}
|
||||||
|
push('/login?next=' + encodeURIComponent(url))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
const routes = {
|
||||||
|
'/': wrap({
|
||||||
|
asyncComponent: () => import('./TargetList.svelte'),
|
||||||
|
props: {
|
||||||
|
'on:navigation': () => redirecting = true,
|
||||||
|
},
|
||||||
|
conditions: [requireLogin],
|
||||||
|
}),
|
||||||
|
'/login': wrap({
|
||||||
|
asyncComponent: () => import('./Login.svelte'),
|
||||||
|
}),
|
||||||
|
'/login/:stateId': wrap({
|
||||||
|
asyncComponent: () => import('./OutOfBandAuth.svelte'),
|
||||||
|
conditions: [requireLogin],
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
init()
|
init()
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
|
@ -38,9 +71,9 @@ init()
|
||||||
<Spinner />
|
<Spinner />
|
||||||
{:else}
|
{:else}
|
||||||
<div class="d-flex align-items-center mt-5 mb-5">
|
<div class="d-flex align-items-center mt-5 mb-5">
|
||||||
<div class="logo">
|
<a class="logo" href="/@warpgate">
|
||||||
<Logo />
|
<Logo />
|
||||||
</div>
|
</a>
|
||||||
|
|
||||||
{#if $serverInfo?.username}
|
{#if $serverInfo?.username}
|
||||||
<div class="ms-auto">
|
<div class="ms-auto">
|
||||||
|
@ -56,12 +89,7 @@ init()
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<main>
|
<main>
|
||||||
{#if $serverInfo?.username}
|
<Router {routes}/>
|
||||||
<TargetList
|
|
||||||
on:navigation={() => redirecting = true} />
|
|
||||||
{:else}
|
|
||||||
<Login />
|
|
||||||
{/if}
|
|
||||||
</main>
|
</main>
|
||||||
|
|
||||||
<footer class="mt-5">
|
<footer class="mt-5">
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import { replace } from 'svelte-spa-router'
|
import { get } from 'svelte/store'
|
||||||
|
import { querystring, replace } from 'svelte-spa-router'
|
||||||
import { Alert, FormGroup, Spinner } from 'sveltestrap'
|
import { Alert, FormGroup, Spinner } from 'sveltestrap'
|
||||||
import Fa from 'svelte-fa'
|
import Fa from 'svelte-fa'
|
||||||
import { faArrowRight } from '@fortawesome/free-solid-svg-icons'
|
import { faArrowRight } from '@fortawesome/free-solid-svg-icons'
|
||||||
|
@ -9,25 +10,47 @@ import { api, ApiAuthState, LoginFailureResponseFromJSON, SsoProviderDescription
|
||||||
import { reloadServerInfo } from 'gateway/lib/store'
|
import { reloadServerInfo } from 'gateway/lib/store'
|
||||||
import AsyncButton from 'common/AsyncButton.svelte'
|
import AsyncButton from 'common/AsyncButton.svelte'
|
||||||
|
|
||||||
|
export let params: { stateId?: string } = {}
|
||||||
|
|
||||||
let error: Error|null = null
|
let error: Error|null = null
|
||||||
let username = ''
|
let username = ''
|
||||||
let password = ''
|
let password = ''
|
||||||
let otp = ''
|
let otp = ''
|
||||||
let busy = false
|
let busy = false
|
||||||
|
|
||||||
let authState = ApiAuthState.NotStarted
|
let authState: ApiAuthState|undefined = undefined
|
||||||
|
|
||||||
let ssoProvidersPromise = api.getSsoProviders()
|
let ssoProvidersPromise = api.getSsoProviders()
|
||||||
|
|
||||||
const nextURL = new URLSearchParams(location.search).get('next')
|
const nextURL = new URLSearchParams(get(querystring)).get('next') ?? undefined
|
||||||
const serverErrorMessage = new URLSearchParams(location.search).get('login_error')
|
const serverErrorMessage = new URLSearchParams(location.search).get('login_error')
|
||||||
|
|
||||||
async function init () {
|
async function init () {
|
||||||
authState = (await api.getAuthState()).state
|
try {
|
||||||
|
authState = (await api.getDefaultAuthState()).state
|
||||||
|
} catch (err) {
|
||||||
|
if (err.status) {
|
||||||
|
const response = err as Response
|
||||||
|
if (response.status === 404) {
|
||||||
|
authState = ApiAuthState.NotStarted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
continueWithState()
|
continueWithState()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function success () {
|
||||||
|
if (nextURL) {
|
||||||
|
replace(nextURL)
|
||||||
|
} else {
|
||||||
|
replace('/')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async function continueWithState () {
|
async function continueWithState () {
|
||||||
|
if (authState === ApiAuthState.Success) {
|
||||||
|
success()
|
||||||
|
}
|
||||||
if (authState === ApiAuthState.SsoNeeded) {
|
if (authState === ApiAuthState.SsoNeeded) {
|
||||||
const providers = await ssoProvidersPromise
|
const providers = await ssoProvidersPromise
|
||||||
if (!providers.length) {
|
if (!providers.length) {
|
||||||
|
@ -65,18 +88,15 @@ async function _login () {
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if (nextURL) {
|
await reloadServerInfo()
|
||||||
location.href = nextURL
|
success()
|
||||||
} else {
|
|
||||||
await reloadServerInfo()
|
|
||||||
replace('/')
|
|
||||||
}
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err.status) {
|
if (err.status) {
|
||||||
const response = err as Response
|
const response = err as Response
|
||||||
if (response.status === 401) {
|
if (response.status === 401) {
|
||||||
const failure = LoginFailureResponseFromJSON(await response.json())
|
const failure = LoginFailureResponseFromJSON(await response.json())
|
||||||
authState = failure.state
|
authState = failure.state
|
||||||
|
|
||||||
continueWithState()
|
continueWithState()
|
||||||
} else {
|
} else {
|
||||||
error = new Error(await response.text())
|
error = new Error(await response.text())
|
||||||
|
@ -96,8 +116,8 @@ function onInputKey (event: KeyboardEvent) {
|
||||||
async function startSSO (provider: SsoProviderDescription) {
|
async function startSSO (provider: SsoProviderDescription) {
|
||||||
busy = true
|
busy = true
|
||||||
try {
|
try {
|
||||||
const params = await api.startSso(provider)
|
const p = await api.startSso({ name: provider.name, next: nextURL })
|
||||||
location.href = params.url
|
location.href = p.url
|
||||||
} catch {
|
} catch {
|
||||||
busy = false
|
busy = false
|
||||||
}
|
}
|
||||||
|
@ -115,6 +135,10 @@ async function startSSO (provider: SsoProviderDescription) {
|
||||||
<h1>Continue login</h1>
|
<h1>Continue login</h1>
|
||||||
{/if}
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
|
{#if params.stateId}
|
||||||
|
//todo
|
||||||
|
loggin in for auth state id {params.stateId}
|
||||||
|
{/if}
|
||||||
{#if authState === ApiAuthState.OtpNeeded}
|
{#if authState === ApiAuthState.OtpNeeded}
|
||||||
<FormGroup floating label="One-time password">
|
<FormGroup floating label="One-time password">
|
||||||
<!-- svelte-ignore a11y-autofocus -->
|
<!-- svelte-ignore a11y-autofocus -->
|
||||||
|
|
65
warpgate-web/src/gateway/OutOfBandAuth.svelte
Normal file
65
warpgate-web/src/gateway/OutOfBandAuth.svelte
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
<script lang="ts">
|
||||||
|
import { Alert, Spinner } from 'sveltestrap'
|
||||||
|
|
||||||
|
import { api, ApiAuthState, AuthStateResponseInternal } from 'gateway/lib/api'
|
||||||
|
import AsyncButton from 'common/AsyncButton.svelte'
|
||||||
|
|
||||||
|
export let params: { stateId: string }
|
||||||
|
let authState: AuthStateResponseInternal
|
||||||
|
|
||||||
|
async function reload () {
|
||||||
|
authState = await api.getAuthState({ id: params.stateId })
|
||||||
|
}
|
||||||
|
|
||||||
|
async function init () {
|
||||||
|
await reload()
|
||||||
|
}
|
||||||
|
|
||||||
|
async function approve () {
|
||||||
|
api.approveAuth({ id: params.stateId })
|
||||||
|
await reload()
|
||||||
|
}
|
||||||
|
|
||||||
|
async function reject () {
|
||||||
|
api.rejectAuth({ id: params.stateId })
|
||||||
|
await reload()
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
|
||||||
|
{#await init()}
|
||||||
|
<Spinner />
|
||||||
|
{:then}
|
||||||
|
<div class="page-summary-bar">
|
||||||
|
<h1>Authorization request</h1>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p>Authorize this {authState.protocol} session?</p>
|
||||||
|
|
||||||
|
{#if authState.state === ApiAuthState.Success}
|
||||||
|
<Alert color="success">
|
||||||
|
Approved
|
||||||
|
</Alert>
|
||||||
|
{:else if authState.state === ApiAuthState.Failed}
|
||||||
|
<Alert color="danger">
|
||||||
|
Rejected
|
||||||
|
</Alert>
|
||||||
|
{:else}
|
||||||
|
<div class="d-flex">
|
||||||
|
<AsyncButton
|
||||||
|
color="primary"
|
||||||
|
class="d-flex align-items-center ms-auto"
|
||||||
|
click={approve}
|
||||||
|
>
|
||||||
|
Authorize
|
||||||
|
</AsyncButton>
|
||||||
|
<AsyncButton
|
||||||
|
outline
|
||||||
|
color="secondary"
|
||||||
|
class="d-flex align-items-center ms-2"
|
||||||
|
click={reject}
|
||||||
|
>
|
||||||
|
Reject
|
||||||
|
</AsyncButton>
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
{/await}
|
|
@ -71,6 +71,16 @@
|
||||||
"operationId": "otpLogin"
|
"operationId": "otpLogin"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/auth/logout": {
|
||||||
|
"post": {
|
||||||
|
"responses": {
|
||||||
|
"201": {
|
||||||
|
"description": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "logout"
|
||||||
|
}
|
||||||
|
},
|
||||||
"/auth/state": {
|
"/auth/state": {
|
||||||
"get": {
|
"get": {
|
||||||
"responses": {
|
"responses": {
|
||||||
|
@ -83,19 +93,108 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
},
|
"404": {
|
||||||
"operationId": "getAuthState"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/auth/logout": {
|
|
||||||
"post": {
|
|
||||||
"responses": {
|
|
||||||
"201": {
|
|
||||||
"description": ""
|
"description": ""
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"operationId": "logout"
|
"operationId": "getDefaultAuthState"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/auth/state/{id}": {
|
||||||
|
"get": {
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "id",
|
||||||
|
"schema": {
|
||||||
|
"type": "string",
|
||||||
|
"format": "uuid"
|
||||||
|
},
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"deprecated": false
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/AuthStateResponseInternal"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"404": {
|
||||||
|
"description": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "get_auth_state"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/auth/state/{id}/approve": {
|
||||||
|
"post": {
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "id",
|
||||||
|
"schema": {
|
||||||
|
"type": "string",
|
||||||
|
"format": "uuid"
|
||||||
|
},
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"deprecated": false
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/AuthStateResponseInternal"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"404": {
|
||||||
|
"description": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "approve_auth"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"/auth/state/{id}/reject": {
|
||||||
|
"post": {
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "id",
|
||||||
|
"schema": {
|
||||||
|
"type": "string",
|
||||||
|
"format": "uuid"
|
||||||
|
},
|
||||||
|
"in": "path",
|
||||||
|
"required": true,
|
||||||
|
"deprecated": false
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/AuthStateResponseInternal"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"404": {
|
||||||
|
"description": ""
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"operationId": "reject_auth"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/info": {
|
"/info": {
|
||||||
|
@ -187,6 +286,15 @@
|
||||||
"in": "path",
|
"in": "path",
|
||||||
"required": true,
|
"required": true,
|
||||||
"deprecated": false
|
"deprecated": false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "next",
|
||||||
|
"schema": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"in": "query",
|
||||||
|
"required": false,
|
||||||
|
"deprecated": false
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"responses": {
|
"responses": {
|
||||||
|
@ -218,15 +326,21 @@
|
||||||
"PasswordNeeded",
|
"PasswordNeeded",
|
||||||
"OtpNeeded",
|
"OtpNeeded",
|
||||||
"SsoNeeded",
|
"SsoNeeded",
|
||||||
|
"WebUserApprovalNeeded",
|
||||||
|
"PublicKeyNeeded",
|
||||||
"Success"
|
"Success"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"AuthStateResponseInternal": {
|
"AuthStateResponseInternal": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
|
"protocol",
|
||||||
"state"
|
"state"
|
||||||
],
|
],
|
||||||
"properties": {
|
"properties": {
|
||||||
|
"protocol": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
"state": {
|
"state": {
|
||||||
"$ref": "#/components/schemas/ApiAuthState"
|
"$ref": "#/components/schemas/ApiAuthState"
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,6 +30,9 @@ pub(crate) async fn command(cli: &crate::Cli) -> Result<()> {
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Checking MySQL key".to_string())?;
|
.with_context(|| "Checking MySQL key".to_string())?;
|
||||||
}
|
}
|
||||||
|
if !config.store.sso_providers.is_empty() && config.store.external_host.is_none() {
|
||||||
|
anyhow::bail!("SSO requires the external_host config option");
|
||||||
|
}
|
||||||
info!("No problems found");
|
info!("No problems found");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,7 +77,6 @@ fn check_and_migrate_config(store: &mut serde_yaml::Value) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn watch_config<P: AsRef<Path> + Send + 'static>(
|
pub fn watch_config<P: AsRef<Path> + Send + 'static>(
|
||||||
path: P,
|
path: P,
|
||||||
config: Arc<Mutex<WarpgateConfig>>,
|
config: Arc<Mutex<WarpgateConfig>>,
|
||||||
|
|
Loading…
Reference in a new issue