mirror of
https://github.com/warp-tech/warpgate.git
synced 2024-09-20 06:46:17 +08:00
parent
c6885f18c3
commit
c8a004f756
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
|
@ -12,6 +12,7 @@ jobs:
|
||||||
|
|
||||||
- uses: actions-rs/toolchain@v1
|
- uses: actions-rs/toolchain@v1
|
||||||
with:
|
with:
|
||||||
|
toolchain: nightly
|
||||||
target: x86_64-unknown-linux-gnu
|
target: x86_64-unknown-linux-gnu
|
||||||
override: true
|
override: true
|
||||||
|
|
||||||
|
@ -41,6 +42,7 @@ jobs:
|
||||||
uses: actions-rs/cargo@v1
|
uses: actions-rs/cargo@v1
|
||||||
with:
|
with:
|
||||||
command: build
|
command: build
|
||||||
|
toolchain: nightly
|
||||||
use-cross: true
|
use-cross: true
|
||||||
args: --release --target x86_64-unknown-linux-gnu -Ztarget-applies-to-host
|
args: --release --target x86_64-unknown-linux-gnu -Ztarget-applies-to-host
|
||||||
|
|
||||||
|
|
6
Cargo.lock
generated
6
Cargo.lock
generated
|
@ -3143,9 +3143,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "russh"
|
name = "russh"
|
||||||
version = "0.34.0-beta.8"
|
version = "0.34.0-beta.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ccd8be93ee0b54a8a6b74c77ecef946185f0acbfb5234ea66666887621381e85"
|
checksum = "4f3bb72a66e32d52e0e258627d141d5c93b408e050f15033699caa836d064c7e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aes 0.8.1",
|
"aes 0.8.1",
|
||||||
"aes-gcm 0.10.1",
|
"aes-gcm 0.10.1",
|
||||||
|
@ -4620,7 +4620,6 @@ dependencies = [
|
||||||
"bytes 1.2.1",
|
"bytes 1.2.1",
|
||||||
"chrono",
|
"chrono",
|
||||||
"data-encoding",
|
"data-encoding",
|
||||||
"futures",
|
|
||||||
"humantime-serde",
|
"humantime-serde",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
|
@ -4692,7 +4691,6 @@ version = "0.4.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"chrono",
|
|
||||||
"cookie",
|
"cookie",
|
||||||
"data-encoding",
|
"data-encoding",
|
||||||
"delegate",
|
"delegate",
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
nightly-2022-08-01
|
|
2
rust-toolchain.toml
Normal file
2
rust-toolchain.toml
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
[toolchain]
|
||||||
|
channel = "nightly-2022-07-22"
|
|
@ -13,7 +13,6 @@ chrono = { version = "0.4", features = ["serde"] }
|
||||||
data-encoding = "2.3"
|
data-encoding = "2.3"
|
||||||
humantime-serde = "1.1"
|
humantime-serde = "1.1"
|
||||||
lazy_static = "1.4"
|
lazy_static = "1.4"
|
||||||
futures = "0.3"
|
|
||||||
once_cell = "1.10"
|
once_cell = "1.10"
|
||||||
packet = "0.1"
|
packet = "0.1"
|
||||||
password-hash = "0.4"
|
password-hash = "0.4"
|
||||||
|
|
|
@ -13,8 +13,6 @@ pub enum CredentialKind {
|
||||||
Otp,
|
Otp,
|
||||||
#[serde(rename = "sso")]
|
#[serde(rename = "sso")]
|
||||||
Sso,
|
Sso,
|
||||||
#[serde(rename = "web")]
|
|
||||||
WebUserApproval,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
@ -29,7 +27,6 @@ pub enum AuthCredential {
|
||||||
provider: String,
|
provider: String,
|
||||||
email: String,
|
email: String,
|
||||||
},
|
},
|
||||||
WebUserApproval,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AuthCredential {
|
impl AuthCredential {
|
||||||
|
@ -39,7 +36,6 @@ impl AuthCredential {
|
||||||
Self::PublicKey { .. } => CredentialKind::PublicKey,
|
Self::PublicKey { .. } => CredentialKind::PublicKey,
|
||||||
Self::Otp { .. } => CredentialKind::Otp,
|
Self::Otp { .. } => CredentialKind::Otp,
|
||||||
Self::Sso { .. } => CredentialKind::Sso,
|
Self::Sso { .. } => CredentialKind::Sso,
|
||||||
Self::WebUserApproval => CredentialKind::WebUserApproval,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use super::{AuthCredential, CredentialKind};
|
use super::{AuthCredential, CredentialKind};
|
||||||
|
use crate::UserRequireCredentialsPolicy;
|
||||||
|
|
||||||
pub enum CredentialPolicyResponse {
|
pub enum CredentialPolicyResponse {
|
||||||
Ok,
|
Ok,
|
||||||
Need(HashSet<CredentialKind>),
|
NeedMoreCredentials,
|
||||||
|
Need(CredentialKind),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait CredentialPolicy {
|
pub trait CredentialPolicy {
|
||||||
|
@ -15,71 +17,36 @@ pub trait CredentialPolicy {
|
||||||
) -> CredentialPolicyResponse;
|
) -> CredentialPolicyResponse;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct AnySingleCredentialPolicy {
|
impl CredentialPolicy for UserRequireCredentialsPolicy {
|
||||||
pub supported_credential_types: HashSet<CredentialKind>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct AllCredentialsPolicy {
|
|
||||||
pub required_credential_types: HashSet<CredentialKind>,
|
|
||||||
pub supported_credential_types: HashSet<CredentialKind>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct PerProtocolCredentialPolicy {
|
|
||||||
pub protocols: HashMap<&'static str, Box<dyn CredentialPolicy + Send + Sync>>,
|
|
||||||
pub default: Box<dyn CredentialPolicy + Send + Sync>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CredentialPolicy for AnySingleCredentialPolicy {
|
|
||||||
fn is_sufficient(
|
|
||||||
&self,
|
|
||||||
_protocol: &str,
|
|
||||||
valid_credentials: &[AuthCredential],
|
|
||||||
) -> CredentialPolicyResponse {
|
|
||||||
if valid_credentials.is_empty() {
|
|
||||||
CredentialPolicyResponse::Need(
|
|
||||||
self.supported_credential_types
|
|
||||||
.clone()
|
|
||||||
.into_iter()
|
|
||||||
.collect(),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
CredentialPolicyResponse::Ok
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CredentialPolicy for AllCredentialsPolicy {
|
|
||||||
fn is_sufficient(
|
|
||||||
&self,
|
|
||||||
_protocol: &str,
|
|
||||||
valid_credentials: &[AuthCredential],
|
|
||||||
) -> CredentialPolicyResponse {
|
|
||||||
let valid_credential_types: HashSet<CredentialKind> =
|
|
||||||
valid_credentials.iter().map(|x| x.kind()).collect();
|
|
||||||
|
|
||||||
if valid_credential_types.is_superset(&self.required_credential_types) {
|
|
||||||
CredentialPolicyResponse::Ok
|
|
||||||
} else {
|
|
||||||
CredentialPolicyResponse::Need(
|
|
||||||
self.required_credential_types
|
|
||||||
.difference(&valid_credential_types)
|
|
||||||
.cloned()
|
|
||||||
.collect(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CredentialPolicy for PerProtocolCredentialPolicy {
|
|
||||||
fn is_sufficient(
|
fn is_sufficient(
|
||||||
&self,
|
&self,
|
||||||
protocol: &str,
|
protocol: &str,
|
||||||
valid_credentials: &[AuthCredential],
|
valid_credentials: &[AuthCredential],
|
||||||
) -> CredentialPolicyResponse {
|
) -> CredentialPolicyResponse {
|
||||||
if let Some(policy) = self.protocols.get(protocol) {
|
let required_kinds = match protocol {
|
||||||
policy.is_sufficient(protocol, valid_credentials)
|
"SSH" => &self.ssh,
|
||||||
|
"HTTP" => &self.http,
|
||||||
|
"MySQL" => &self.mysql,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
if let Some(required_kinds) = required_kinds {
|
||||||
|
let mut remaining_required_kinds = HashSet::<CredentialKind>::new();
|
||||||
|
remaining_required_kinds.extend(required_kinds);
|
||||||
|
for kind in required_kinds {
|
||||||
|
if valid_credentials.iter().any(|x| x.kind() == *kind) {
|
||||||
|
remaining_required_kinds.remove(kind);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(kind) = remaining_required_kinds.into_iter().next() {
|
||||||
|
CredentialPolicyResponse::Need(kind)
|
||||||
} else {
|
} else {
|
||||||
self.default.is_sufficient(protocol, valid_credentials)
|
CredentialPolicyResponse::Ok
|
||||||
|
}
|
||||||
|
} else if valid_credentials.is_empty() {
|
||||||
|
CredentialPolicyResponse::NeedMoreCredentials
|
||||||
|
} else {
|
||||||
|
CredentialPolicyResponse::Ok
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,66 +1,71 @@
|
||||||
use uuid::Uuid;
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
use super::{AuthCredential, CredentialPolicy, CredentialPolicyResponse};
|
use super::{AuthCredential, CredentialPolicy, CredentialPolicyResponse};
|
||||||
use crate::AuthResult;
|
use crate::AuthResult;
|
||||||
|
|
||||||
|
#[allow(clippy::unwrap_used)]
|
||||||
|
pub static TIMEOUT: Lazy<Duration> = Lazy::new(|| Duration::from_secs(60 * 10));
|
||||||
|
|
||||||
pub struct AuthState {
|
pub struct AuthState {
|
||||||
id: Uuid,
|
|
||||||
username: String,
|
username: String,
|
||||||
protocol: String,
|
protocol: String,
|
||||||
force_rejected: bool,
|
policy: Option<Box<dyn CredentialPolicy + Sync + Send>>,
|
||||||
policy: Box<dyn CredentialPolicy + Sync + Send>,
|
|
||||||
valid_credentials: Vec<AuthCredential>,
|
valid_credentials: Vec<AuthCredential>,
|
||||||
|
started_at: Instant,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AuthState {
|
impl AuthState {
|
||||||
pub(crate) fn new(
|
pub fn new(
|
||||||
id: Uuid,
|
|
||||||
username: String,
|
username: String,
|
||||||
protocol: String,
|
protocol: String,
|
||||||
policy: Box<dyn CredentialPolicy + Sync + Send>,
|
policy: Option<Box<dyn CredentialPolicy + Sync + Send>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
id,
|
|
||||||
username,
|
username,
|
||||||
protocol,
|
protocol,
|
||||||
force_rejected: false,
|
|
||||||
policy,
|
policy,
|
||||||
valid_credentials: vec![],
|
valid_credentials: vec![],
|
||||||
|
started_at: Instant::now(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn id(&self) -> &Uuid {
|
|
||||||
&self.id
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn username(&self) -> &str {
|
pub fn username(&self) -> &str {
|
||||||
&self.username
|
&self.username
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn protocol(&self) -> &str {
|
|
||||||
&self.protocol
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add_valid_credential(&mut self, credential: AuthCredential) {
|
pub fn add_valid_credential(&mut self, credential: AuthCredential) {
|
||||||
self.valid_credentials.push(credential);
|
self.valid_credentials.push(credential);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn reject(&mut self) {
|
pub fn is_expired(&self) -> bool {
|
||||||
self.force_rejected = true;
|
self.started_at.elapsed() > *TIMEOUT
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn verify(&self) -> AuthResult {
|
pub fn verify(&self) -> AuthResult {
|
||||||
if self.force_rejected {
|
if self.valid_credentials.is_empty() {
|
||||||
|
warn!(
|
||||||
|
username=%self.username,
|
||||||
|
"No matching valid credentials"
|
||||||
|
);
|
||||||
return AuthResult::Rejected;
|
return AuthResult::Rejected;
|
||||||
}
|
}
|
||||||
match self
|
|
||||||
.policy
|
if let Some(ref policy) = self.policy {
|
||||||
.is_sufficient(&self.protocol, &self.valid_credentials[..])
|
match policy.is_sufficient(&self.protocol, &self.valid_credentials[..]) {
|
||||||
{
|
CredentialPolicyResponse::Ok => {}
|
||||||
CredentialPolicyResponse::Ok => AuthResult::Accepted {
|
CredentialPolicyResponse::Need(kind) => {
|
||||||
|
return AuthResult::Need(kind);
|
||||||
|
}
|
||||||
|
CredentialPolicyResponse::NeedMoreCredentials => {
|
||||||
|
return AuthResult::Rejected;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
AuthResult::Accepted {
|
||||||
username: self.username.clone(),
|
username: self.username.clone(),
|
||||||
},
|
|
||||||
CredentialPolicyResponse::Need(kinds) => AuthResult::Need(kinds),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,32 +1,15 @@
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
|
||||||
|
|
||||||
use once_cell::sync::Lazy;
|
use tokio::sync::Mutex;
|
||||||
use tokio::sync::{broadcast, Mutex};
|
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use super::AuthState;
|
use super::AuthState;
|
||||||
use crate::{AuthResult, ConfigProvider, WarpgateError};
|
use crate::{ConfigProvider, WarpgateError};
|
||||||
|
|
||||||
#[allow(clippy::unwrap_used)]
|
|
||||||
pub static TIMEOUT: Lazy<Duration> = Lazy::new(|| Duration::from_secs(60 * 10));
|
|
||||||
|
|
||||||
struct AuthCompletionSignal {
|
|
||||||
sender: broadcast::Sender<AuthResult>,
|
|
||||||
created_at: Instant,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AuthCompletionSignal {
|
|
||||||
pub fn is_expired(&self) -> bool {
|
|
||||||
self.created_at.elapsed() > *TIMEOUT
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct AuthStateStore {
|
pub struct AuthStateStore {
|
||||||
config_provider: Arc<Mutex<dyn ConfigProvider + Send + 'static>>,
|
config_provider: Arc<Mutex<dyn ConfigProvider + Send + 'static>>,
|
||||||
store: HashMap<Uuid, (Arc<Mutex<AuthState>>, Instant)>,
|
store: HashMap<Uuid, AuthState>,
|
||||||
completion_signals: HashMap<Uuid, AuthCompletionSignal>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AuthStateStore {
|
impl AuthStateStore {
|
||||||
|
@ -34,66 +17,47 @@ impl AuthStateStore {
|
||||||
Self {
|
Self {
|
||||||
store: HashMap::new(),
|
store: HashMap::new(),
|
||||||
config_provider,
|
config_provider,
|
||||||
completion_signals: HashMap::new(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn contains_key(&self, id: &Uuid) -> bool {
|
pub fn contains_key(&mut self, id: &Uuid) -> bool {
|
||||||
self.store.contains_key(id)
|
self.store.contains_key(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self, id: &Uuid) -> Option<Arc<Mutex<AuthState>>> {
|
pub fn get_mut(&mut self, id: &Uuid) -> Option<&mut AuthState> {
|
||||||
self.store.get(id).map(|x| x.0.clone())
|
self.store.get_mut(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create(
|
pub async fn create(
|
||||||
&mut self,
|
&mut self,
|
||||||
username: &str,
|
username: &str,
|
||||||
protocol: &str,
|
protocol: &str,
|
||||||
) -> Result<(Uuid, Arc<Mutex<AuthState>>), WarpgateError> {
|
) -> Result<(Uuid, &mut AuthState), WarpgateError> {
|
||||||
let id = Uuid::new_v4();
|
let id = Uuid::new_v4();
|
||||||
let Some(policy) = self.config_provider
|
let state = AuthState::new(
|
||||||
|
username.to_string(),
|
||||||
|
protocol.to_string(),
|
||||||
|
self.config_provider
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
.await
|
||||||
.get_credential_policy(username)
|
.get_credential_policy(username)
|
||||||
.await? else {
|
.await?,
|
||||||
return Err(WarpgateError::UserNotFound)
|
);
|
||||||
};
|
self.store.insert(id, state);
|
||||||
|
|
||||||
let state = AuthState::new(id, username.to_string(), protocol.to_string(), policy);
|
|
||||||
self.store
|
|
||||||
.insert(id, (Arc::new(Mutex::new(state)), Instant::now()));
|
|
||||||
|
|
||||||
#[allow(clippy::unwrap_used)]
|
#[allow(clippy::unwrap_used)]
|
||||||
Ok((id, self.get(&id).unwrap()))
|
Ok((id, self.store.get_mut(&id).unwrap()))
|
||||||
}
|
|
||||||
|
|
||||||
pub fn subscribe(&mut self, id: Uuid) -> broadcast::Receiver<AuthResult> {
|
|
||||||
let signal = self.completion_signals.entry(id).or_insert_with(|| {
|
|
||||||
let (sender, _) = broadcast::channel(1);
|
|
||||||
AuthCompletionSignal {
|
|
||||||
sender,
|
|
||||||
created_at: Instant::now(),
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
signal.sender.subscribe()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn complete(&mut self, id: &Uuid) {
|
|
||||||
let Some((state, _)) = self.store.get(id) else {
|
|
||||||
return
|
|
||||||
};
|
|
||||||
if let Some(sig) = self.completion_signals.remove(id) {
|
|
||||||
let _ = sig.sender.send(state.lock().await.verify());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn vacuum(&mut self) {
|
pub async fn vacuum(&mut self) {
|
||||||
self.store
|
let mut to_remove = vec![];
|
||||||
.retain(|_, (_, started_at)| started_at.elapsed() < *TIMEOUT);
|
for (id, state) in self.store.iter() {
|
||||||
|
if state.is_expired() {
|
||||||
self.completion_signals
|
to_remove.push(*id);
|
||||||
.retain(|_, signal| !signal.is_expired());
|
}
|
||||||
|
}
|
||||||
|
for id in to_remove {
|
||||||
|
self.store.remove(&id);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,12 +5,11 @@ use std::time::Duration;
|
||||||
|
|
||||||
use poem_openapi::{Enum, Object, Union};
|
use poem_openapi::{Enum, Object, Union};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use url::Url;
|
|
||||||
use warpgate_sso::SsoProviderConfig;
|
use warpgate_sso::SsoProviderConfig;
|
||||||
|
|
||||||
use crate::auth::CredentialKind;
|
use crate::auth::CredentialKind;
|
||||||
use crate::helpers::otp::OtpSecretKey;
|
use crate::helpers::otp::OtpSecretKey;
|
||||||
use crate::{ListenEndpoint, Secret, WarpgateError};
|
use crate::{ListenEndpoint, Secret};
|
||||||
|
|
||||||
const fn _default_true() -> bool {
|
const fn _default_true() -> bool {
|
||||||
true
|
true
|
||||||
|
@ -427,24 +426,3 @@ pub struct WarpgateConfig {
|
||||||
pub store: WarpgateConfigStore,
|
pub store: WarpgateConfigStore,
|
||||||
pub paths_relative_to: PathBuf,
|
pub paths_relative_to: PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WarpgateConfig {
|
|
||||||
pub fn construct_external_url(
|
|
||||||
&self,
|
|
||||||
fallback_host: Option<&str>,
|
|
||||||
) -> Result<Url, WarpgateError> {
|
|
||||||
let ext_host = self.store.external_host.as_deref().or(fallback_host);
|
|
||||||
let Some(ext_host) = ext_host else {
|
|
||||||
return Err(WarpgateError::ExternalHostNotSet);
|
|
||||||
};
|
|
||||||
let ext_port = self.store.http.listen.port();
|
|
||||||
|
|
||||||
let mut url = Url::parse(&format!("https://{ext_host}/"))?;
|
|
||||||
|
|
||||||
if ext_port != 443 {
|
|
||||||
let _ = url.set_port(Some(ext_port));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(url)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::HashSet;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
@ -11,10 +11,7 @@ use uuid::Uuid;
|
||||||
use warpgate_db_entities::Ticket;
|
use warpgate_db_entities::Ticket;
|
||||||
|
|
||||||
use super::ConfigProvider;
|
use super::ConfigProvider;
|
||||||
use crate::auth::{
|
use crate::auth::{AuthCredential, CredentialPolicy};
|
||||||
AllCredentialsPolicy, AnySingleCredentialPolicy, AuthCredential, CredentialKind,
|
|
||||||
CredentialPolicy, PerProtocolCredentialPolicy,
|
|
||||||
};
|
|
||||||
use crate::helpers::hash::verify_password_hash;
|
use crate::helpers::hash::verify_password_hash;
|
||||||
use crate::helpers::otp::verify_totp;
|
use crate::helpers::otp::verify_totp;
|
||||||
use crate::{Target, User, UserAuthCredential, UserSnapshot, WarpgateConfig, WarpgateError};
|
use crate::{Target, User, UserAuthCredential, UserSnapshot, WarpgateConfig, WarpgateError};
|
||||||
|
@ -81,52 +78,9 @@ impl ConfigProvider for FileConfigProvider {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
|
||||||
let supported_credential_types: HashSet<CredentialKind> =
|
Ok(user
|
||||||
user.credentials.iter().map(|x| x.kind()).collect();
|
.require
|
||||||
let default_policy = Box::new(AnySingleCredentialPolicy {
|
.map(|r| Box::new(r) as Box<dyn CredentialPolicy + Sync + Send>))
|
||||||
supported_credential_types: supported_credential_types.clone(),
|
|
||||||
}) as Box<dyn CredentialPolicy + Sync + Send>;
|
|
||||||
|
|
||||||
if let Some(req) = user.require {
|
|
||||||
let mut policy = PerProtocolCredentialPolicy {
|
|
||||||
default: default_policy,
|
|
||||||
protocols: HashMap::new(),
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(p) = req.http {
|
|
||||||
policy.protocols.insert(
|
|
||||||
"HTTP",
|
|
||||||
Box::new(AllCredentialsPolicy {
|
|
||||||
supported_credential_types: supported_credential_types.clone(),
|
|
||||||
required_credential_types: p.into_iter().collect(),
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if let Some(p) = req.mysql {
|
|
||||||
policy.protocols.insert(
|
|
||||||
"MySQL",
|
|
||||||
Box::new(AllCredentialsPolicy {
|
|
||||||
supported_credential_types: supported_credential_types.clone(),
|
|
||||||
required_credential_types: p.into_iter().collect(),
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if let Some(p) = req.ssh {
|
|
||||||
policy.protocols.insert(
|
|
||||||
"SSH",
|
|
||||||
Box::new(AllCredentialsPolicy {
|
|
||||||
supported_credential_types,
|
|
||||||
required_credential_types: p.into_iter().collect(),
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Some(
|
|
||||||
Box::new(policy) as Box<dyn CredentialPolicy + Sync + Send>
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
Ok(Some(default_policy))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn username_for_sso_credential(
|
async fn username_for_sso_credential(
|
||||||
|
@ -239,7 +193,6 @@ impl ConfigProvider for FileConfigProvider {
|
||||||
}
|
}
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
_ => return Err(WarpgateError::InvalidCredentialType),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
mod file;
|
mod file;
|
||||||
use std::collections::HashSet;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
@ -13,10 +12,11 @@ use warpgate_db_entities::Ticket;
|
||||||
use crate::auth::{AuthCredential, CredentialKind, CredentialPolicy};
|
use crate::auth::{AuthCredential, CredentialKind, CredentialPolicy};
|
||||||
use crate::{Secret, Target, UserSnapshot, WarpgateError};
|
use crate::{Secret, Target, UserSnapshot, WarpgateError};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug)]
|
||||||
pub enum AuthResult {
|
pub enum AuthResult {
|
||||||
Accepted { username: String },
|
Accepted { username: String },
|
||||||
Need(HashSet<CredentialKind>),
|
Need(CredentialKind),
|
||||||
|
NeedMoreCredentials,
|
||||||
Rejected,
|
Rejected,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,16 +9,8 @@ pub enum WarpgateError {
|
||||||
DatabaseError(#[from] sea_orm::DbErr),
|
DatabaseError(#[from] sea_orm::DbErr),
|
||||||
#[error("ticket not found: {0}")]
|
#[error("ticket not found: {0}")]
|
||||||
InvalidTicket(Uuid),
|
InvalidTicket(Uuid),
|
||||||
#[error("invalid credential type")]
|
|
||||||
InvalidCredentialType,
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Other(Box<dyn Error + Send + Sync>),
|
Other(Box<dyn Error + Send + Sync>),
|
||||||
#[error("user not found")]
|
|
||||||
UserNotFound,
|
|
||||||
#[error("failed to parse url: {0}")]
|
|
||||||
UrlParse(#[from] url::ParseError),
|
|
||||||
#[error("external_url config option is not set")]
|
|
||||||
ExternalHostNotSet,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResponseError for WarpgateError {
|
impl ResponseError for WarpgateError {
|
||||||
|
|
|
@ -7,7 +7,6 @@ version = "0.4.0"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
chrono = {version = "0.4", features = ["serde"]}
|
|
||||||
cookie = "0.16"
|
cookie = "0.16"
|
||||||
data-encoding = "2.3"
|
data-encoding = "2.3"
|
||||||
delegate = "0.6"
|
delegate = "0.6"
|
||||||
|
|
|
@ -3,18 +3,14 @@ use std::sync::Arc;
|
||||||
use poem::session::Session;
|
use poem::session::Session;
|
||||||
use poem::web::Data;
|
use poem::web::Data;
|
||||||
use poem::Request;
|
use poem::Request;
|
||||||
use poem_openapi::param::Path;
|
|
||||||
use poem_openapi::payload::Json;
|
use poem_openapi::payload::Json;
|
||||||
use poem_openapi::{ApiResponse, Enum, Object, OpenApi};
|
use poem_openapi::{ApiResponse, Enum, Object, OpenApi};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tracing::*;
|
use tracing::*;
|
||||||
use uuid::Uuid;
|
use warpgate_common::auth::{AuthCredential, CredentialKind};
|
||||||
use warpgate_common::auth::{AuthCredential, AuthState, CredentialKind};
|
|
||||||
use warpgate_common::{AuthResult, Secret, Services};
|
use warpgate_common::{AuthResult, Secret, Services};
|
||||||
|
|
||||||
use crate::common::{
|
use crate::common::{authorize_session, get_auth_state_for_request, SessionExt};
|
||||||
authorize_session, endpoint_auth, get_auth_state_for_request, SessionAuthorization, SessionExt,
|
|
||||||
};
|
|
||||||
use crate::session::SessionStore;
|
use crate::session::SessionStore;
|
||||||
|
|
||||||
pub struct Api;
|
pub struct Api;
|
||||||
|
@ -37,8 +33,6 @@ enum ApiAuthState {
|
||||||
PasswordNeeded,
|
PasswordNeeded,
|
||||||
OtpNeeded,
|
OtpNeeded,
|
||||||
SsoNeeded,
|
SsoNeeded,
|
||||||
WebUserApprovalNeeded,
|
|
||||||
PublicKeyNeeded,
|
|
||||||
Success,
|
Success,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,7 +58,6 @@ enum LogoutResponse {
|
||||||
|
|
||||||
#[derive(Object)]
|
#[derive(Object)]
|
||||||
struct AuthStateResponseInternal {
|
struct AuthStateResponseInternal {
|
||||||
pub protocol: String,
|
|
||||||
pub state: ApiAuthState,
|
pub state: ApiAuthState,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,22 +65,17 @@ struct AuthStateResponseInternal {
|
||||||
enum AuthStateResponse {
|
enum AuthStateResponse {
|
||||||
#[oai(status = 200)]
|
#[oai(status = 200)]
|
||||||
Ok(Json<AuthStateResponseInternal>),
|
Ok(Json<AuthStateResponseInternal>),
|
||||||
#[oai(status = 404)]
|
|
||||||
NotFound,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<AuthResult> for ApiAuthState {
|
impl From<AuthResult> for ApiAuthState {
|
||||||
fn from(state: AuthResult) -> Self {
|
fn from(state: AuthResult) -> Self {
|
||||||
match state {
|
match state {
|
||||||
AuthResult::Rejected => ApiAuthState::Failed,
|
AuthResult::Rejected => ApiAuthState::Failed,
|
||||||
AuthResult::Need(kinds) => match kinds.iter().next() {
|
AuthResult::Need(CredentialKind::Password) => ApiAuthState::PasswordNeeded,
|
||||||
Some(CredentialKind::Password) => ApiAuthState::PasswordNeeded,
|
AuthResult::Need(CredentialKind::Otp) => ApiAuthState::OtpNeeded,
|
||||||
Some(CredentialKind::Otp) => ApiAuthState::OtpNeeded,
|
AuthResult::Need(CredentialKind::Sso) => ApiAuthState::SsoNeeded,
|
||||||
Some(CredentialKind::Sso) => ApiAuthState::SsoNeeded,
|
AuthResult::Need(CredentialKind::PublicKey) => ApiAuthState::Failed,
|
||||||
Some(CredentialKind::WebUserApproval) => ApiAuthState::WebUserApprovalNeeded,
|
AuthResult::NeedMoreCredentials => ApiAuthState::Failed,
|
||||||
Some(CredentialKind::PublicKey) => ApiAuthState::PublicKeyNeeded,
|
|
||||||
None => ApiAuthState::Failed,
|
|
||||||
},
|
|
||||||
AuthResult::Accepted { .. } => ApiAuthState::Success,
|
AuthResult::Accepted { .. } => ApiAuthState::Success,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -104,9 +92,8 @@ impl Api {
|
||||||
body: Json<LoginRequest>,
|
body: Json<LoginRequest>,
|
||||||
) -> poem::Result<LoginResponse> {
|
) -> poem::Result<LoginResponse> {
|
||||||
let mut auth_state_store = services.auth_state_store.lock().await;
|
let mut auth_state_store = services.auth_state_store.lock().await;
|
||||||
let state_arc =
|
let state =
|
||||||
get_auth_state_for_request(&body.username, session, &mut auth_state_store).await?;
|
get_auth_state_for_request(&body.username, session, &mut auth_state_store).await?;
|
||||||
let mut state = state_arc.lock().await;
|
|
||||||
|
|
||||||
let mut cp = services.config_provider.lock().await;
|
let mut cp = services.config_provider.lock().await;
|
||||||
|
|
||||||
|
@ -120,7 +107,6 @@ impl Api {
|
||||||
|
|
||||||
match state.verify() {
|
match state.verify() {
|
||||||
AuthResult::Accepted { username } => {
|
AuthResult::Accepted { username } => {
|
||||||
auth_state_store.complete(state.id()).await;
|
|
||||||
authorize_session(req, username).await?;
|
authorize_session(req, username).await?;
|
||||||
Ok(LoginResponse::Success)
|
Ok(LoginResponse::Success)
|
||||||
}
|
}
|
||||||
|
@ -145,14 +131,12 @@ impl Api {
|
||||||
|
|
||||||
let mut auth_state_store = services.auth_state_store.lock().await;
|
let mut auth_state_store = services.auth_state_store.lock().await;
|
||||||
|
|
||||||
let Some(state_arc) = state_id.and_then(|id| auth_state_store.get(&id.0)) else {
|
let Some(state) = state_id.and_then(|id| auth_state_store.get_mut(&id.0)) else {
|
||||||
return Ok(LoginResponse::Failure(Json(LoginFailureResponse {
|
return Ok(LoginResponse::Failure(Json(LoginFailureResponse {
|
||||||
state: ApiAuthState::NotStarted,
|
state: ApiAuthState::NotStarted,
|
||||||
})))
|
})))
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut state = state_arc.lock().await;
|
|
||||||
|
|
||||||
let mut cp = services.config_provider.lock().await;
|
let mut cp = services.config_provider.lock().await;
|
||||||
|
|
||||||
let otp_cred = AuthCredential::Otp(body.otp.clone().into());
|
let otp_cred = AuthCredential::Otp(body.otp.clone().into());
|
||||||
|
@ -162,7 +146,6 @@ impl Api {
|
||||||
|
|
||||||
match state.verify() {
|
match state.verify() {
|
||||||
AuthResult::Accepted { username } => {
|
AuthResult::Accepted { username } => {
|
||||||
auth_state_store.complete(state.id()).await;
|
|
||||||
authorize_session(req, username).await?;
|
authorize_session(req, username).await?;
|
||||||
Ok(LoginResponse::Success)
|
Ok(LoginResponse::Success)
|
||||||
}
|
}
|
||||||
|
@ -172,6 +155,27 @@ impl Api {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[oai(path = "/auth/state", method = "get", operation_id = "getAuthState")]
|
||||||
|
async fn api_auth_state(
|
||||||
|
&self,
|
||||||
|
session: &Session,
|
||||||
|
services: Data<&Services>,
|
||||||
|
) -> poem::Result<AuthStateResponse> {
|
||||||
|
let state_id = session.get_auth_state_id();
|
||||||
|
|
||||||
|
let mut auth_state_store = services.auth_state_store.lock().await;
|
||||||
|
|
||||||
|
let Some(state) = state_id.and_then(|id| auth_state_store.get_mut(&id.0)) else {
|
||||||
|
return Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal {
|
||||||
|
state: ApiAuthState::NotStarted,
|
||||||
|
})));
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal {
|
||||||
|
state: state.verify().into(),
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
|
||||||
#[oai(path = "/auth/logout", method = "post", operation_id = "logout")]
|
#[oai(path = "/auth/logout", method = "post", operation_id = "logout")]
|
||||||
async fn api_auth_logout(
|
async fn api_auth_logout(
|
||||||
&self,
|
&self,
|
||||||
|
@ -183,129 +187,4 @@ impl Api {
|
||||||
info!("Logged out");
|
info!("Logged out");
|
||||||
Ok(LogoutResponse::Success)
|
Ok(LogoutResponse::Success)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[oai(
|
|
||||||
path = "/auth/state",
|
|
||||||
method = "get",
|
|
||||||
operation_id = "getDefaultAuthState"
|
|
||||||
)]
|
|
||||||
async fn api_default_auth_state(
|
|
||||||
&self,
|
|
||||||
session: &Session,
|
|
||||||
services: Data<&Services>,
|
|
||||||
) -> poem::Result<AuthStateResponse> {
|
|
||||||
let Some(state_id) = session.get_auth_state_id() else {
|
|
||||||
return Ok(AuthStateResponse::NotFound)
|
|
||||||
};
|
|
||||||
let store = services.auth_state_store.lock().await;
|
|
||||||
let Some(state_arc) = store.get(&state_id.0) else {
|
|
||||||
return Ok(AuthStateResponse::NotFound);
|
|
||||||
};
|
|
||||||
serialize_auth_state_inner(state_arc).await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[oai(
|
|
||||||
path = "/auth/state/:id",
|
|
||||||
method = "get",
|
|
||||||
operation_id = "get_auth_state",
|
|
||||||
transform = "endpoint_auth"
|
|
||||||
)]
|
|
||||||
async fn api_auth_state(
|
|
||||||
&self,
|
|
||||||
services: Data<&Services>,
|
|
||||||
auth: Option<Data<&SessionAuthorization>>,
|
|
||||||
id: Path<Uuid>,
|
|
||||||
) -> poem::Result<AuthStateResponse> {
|
|
||||||
let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else {
|
|
||||||
return Ok(AuthStateResponse::NotFound);
|
|
||||||
};
|
|
||||||
serialize_auth_state_inner(state_arc).await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[oai(
|
|
||||||
path = "/auth/state/:id/approve",
|
|
||||||
method = "post",
|
|
||||||
operation_id = "approve_auth",
|
|
||||||
transform = "endpoint_auth"
|
|
||||||
)]
|
|
||||||
async fn api_approve_auth(
|
|
||||||
&self,
|
|
||||||
services: Data<&Services>,
|
|
||||||
auth: Option<Data<&SessionAuthorization>>,
|
|
||||||
id: Path<Uuid>,
|
|
||||||
) -> poem::Result<AuthStateResponse> {
|
|
||||||
let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else {
|
|
||||||
return Ok(AuthStateResponse::NotFound);
|
|
||||||
};
|
|
||||||
|
|
||||||
let auth_result = {
|
|
||||||
let mut state = state_arc.lock().await;
|
|
||||||
state.add_valid_credential(AuthCredential::WebUserApproval);
|
|
||||||
state.verify()
|
|
||||||
};
|
|
||||||
|
|
||||||
if let AuthResult::Accepted { .. } = auth_result {
|
|
||||||
services.auth_state_store.lock().await.complete(&*id).await;
|
|
||||||
}
|
|
||||||
serialize_auth_state_inner(state_arc).await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[oai(
|
|
||||||
path = "/auth/state/:id/reject",
|
|
||||||
method = "post",
|
|
||||||
operation_id = "reject_auth",
|
|
||||||
transform = "endpoint_auth"
|
|
||||||
)]
|
|
||||||
async fn api_reject_auth(
|
|
||||||
&self,
|
|
||||||
services: Data<&Services>,
|
|
||||||
auth: Option<Data<&SessionAuthorization>>,
|
|
||||||
id: Path<Uuid>,
|
|
||||||
) -> poem::Result<AuthStateResponse> {
|
|
||||||
let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else {
|
|
||||||
return Ok(AuthStateResponse::NotFound);
|
|
||||||
};
|
|
||||||
state_arc.lock().await.reject();
|
|
||||||
services.auth_state_store.lock().await.complete(&*id).await;
|
|
||||||
serialize_auth_state_inner(state_arc).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_auth_state(
|
|
||||||
id: &Uuid,
|
|
||||||
services: &Services,
|
|
||||||
auth: Option<&SessionAuthorization>,
|
|
||||||
) -> Option<Arc<Mutex<AuthState>>> {
|
|
||||||
let store = services.auth_state_store.lock().await;
|
|
||||||
|
|
||||||
let Some(auth) = auth else {
|
|
||||||
return None;
|
|
||||||
};
|
|
||||||
|
|
||||||
let SessionAuthorization::User(username) = auth else {
|
|
||||||
return None;
|
|
||||||
};
|
|
||||||
|
|
||||||
let Some(state_arc) = store.get(&*id) else {
|
|
||||||
return None;
|
|
||||||
};
|
|
||||||
|
|
||||||
{
|
|
||||||
let state = state_arc.lock().await;
|
|
||||||
if state.username() != username {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Some(state_arc)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn serialize_auth_state_inner(
|
|
||||||
state_arc: Arc<Mutex<AuthState>>,
|
|
||||||
) -> poem::Result<AuthStateResponse> {
|
|
||||||
let state = state_arc.lock().await;
|
|
||||||
Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal {
|
|
||||||
protocol: state.protocol().to_string(),
|
|
||||||
state: state.verify().into(),
|
|
||||||
})))
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
use poem::session::Session;
|
use poem::session::Session;
|
||||||
use poem::web::Data;
|
use poem::web::Data;
|
||||||
use poem::Request;
|
use poem::Request;
|
||||||
use poem_openapi::param::{Path, Query};
|
use poem_openapi::param::Path;
|
||||||
use poem_openapi::payload::Json;
|
use poem_openapi::payload::Json;
|
||||||
use poem_openapi::{ApiResponse, Object, OpenApi};
|
use poem_openapi::{ApiResponse, Object, OpenApi};
|
||||||
|
use reqwest::Url;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use warpgate_common::Services;
|
use warpgate_common::Services;
|
||||||
use warpgate_sso::{SsoClient, SsoLoginRequest};
|
use warpgate_sso::{SsoClient, SsoLoginRequest};
|
||||||
|
@ -30,7 +31,6 @@ pub static SSO_CONTEXT_SESSION_KEY: &str = "sso_request";
|
||||||
pub struct SsoContext {
|
pub struct SsoContext {
|
||||||
pub provider: String,
|
pub provider: String,
|
||||||
pub request: SsoLoginRequest,
|
pub request: SsoLoginRequest,
|
||||||
pub next_url: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[OpenApi]
|
#[OpenApi]
|
||||||
|
@ -46,14 +46,31 @@ impl Api {
|
||||||
session: &Session,
|
session: &Session,
|
||||||
services: Data<&Services>,
|
services: Data<&Services>,
|
||||||
name: Path<String>,
|
name: Path<String>,
|
||||||
next: Query<Option<String>>,
|
|
||||||
) -> poem::Result<StartSsoResponse> {
|
) -> poem::Result<StartSsoResponse> {
|
||||||
let config = services.config.lock().await;
|
let config = services.config.lock().await;
|
||||||
|
|
||||||
let name = name.0;
|
let name = name.0;
|
||||||
|
let ext_host = config
|
||||||
|
.store
|
||||||
|
.external_host
|
||||||
|
.as_deref()
|
||||||
|
.or_else(|| req.original_uri().host());
|
||||||
|
let Some(ext_host) = ext_host else {
|
||||||
|
return Err(poem::Error::from_string("external_host config option is required for SSO", http::status::StatusCode::INTERNAL_SERVER_ERROR));
|
||||||
|
};
|
||||||
|
let ext_port = config.store.http.listen.port();
|
||||||
|
|
||||||
let mut return_url = config.construct_external_url(req.original_uri().host())?;
|
let mut return_url = Url::parse(&format!("https://{ext_host}/@warpgate/api/sso/return"))
|
||||||
return_url.set_path("@warpgate/api/sso/return");
|
.map_err(|e| {
|
||||||
|
poem::Error::from_string(
|
||||||
|
format!("failed to construct the return URL: {e}"),
|
||||||
|
http::status::StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if ext_port != 443 {
|
||||||
|
let _ = return_url.set_port(Some(ext_port));
|
||||||
|
}
|
||||||
|
|
||||||
let Some(provider_config) = config.store.sso_providers.iter().find(|p| p.name == *name) else {
|
let Some(provider_config) = config.store.sso_providers.iter().find(|p| p.name == *name) else {
|
||||||
return Ok(StartSsoResponse::NotFound);
|
return Ok(StartSsoResponse::NotFound);
|
||||||
|
@ -67,14 +84,10 @@ impl Api {
|
||||||
.map_err(poem::error::InternalServerError)?;
|
.map_err(poem::error::InternalServerError)?;
|
||||||
|
|
||||||
let url = sso_req.auth_url().to_string();
|
let url = sso_req.auth_url().to_string();
|
||||||
session.set(
|
session.set(SSO_CONTEXT_SESSION_KEY, SsoContext {
|
||||||
SSO_CONTEXT_SESSION_KEY,
|
|
||||||
SsoContext {
|
|
||||||
provider: name,
|
provider: name,
|
||||||
request: sso_req,
|
request: sso_req,
|
||||||
next_url: next.0.clone(),
|
});
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(StartSsoResponse::Ok(Json(StartSsoResponseParams { url })))
|
Ok(StartSsoResponse::Ok(Json(StartSsoResponseParams { url })))
|
||||||
}
|
}
|
||||||
|
|
|
@ -120,9 +120,8 @@ impl Api {
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut auth_state_store = services.auth_state_store.lock().await;
|
let mut auth_state_store = services.auth_state_store.lock().await;
|
||||||
let state_arc = get_auth_state_for_request(&username, session, &mut auth_state_store).await?;
|
let state = get_auth_state_for_request(&username, session, &mut auth_state_store).await?;
|
||||||
|
|
||||||
let mut state = state_arc.lock().await;
|
|
||||||
let mut cp = services.config_provider.lock().await;
|
let mut cp = services.config_provider.lock().await;
|
||||||
|
|
||||||
if cp.validate_credential(&username, &cred).await? {
|
if cp.validate_credential(&username, &cred).await? {
|
||||||
|
@ -131,15 +130,11 @@ impl Api {
|
||||||
|
|
||||||
match state.verify() {
|
match state.verify() {
|
||||||
AuthResult::Accepted { username } => {
|
AuthResult::Accepted { username } => {
|
||||||
auth_state_store.complete(state.id()).await;
|
|
||||||
authorize_session(req, username).await?;
|
authorize_session(req, username).await?;
|
||||||
}
|
}
|
||||||
_ => (),
|
_ => ()
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Response::new(ReturnToSsoResponse::Ok).header(
|
Ok(Response::new(ReturnToSsoResponse::Ok).header("Location", "/@warpgate"))
|
||||||
"Location",
|
|
||||||
context.next_url.as_deref().unwrap_or("/@warpgate#/login"),
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -170,18 +170,18 @@ pub fn gateway_redirect(req: &Request) -> Response {
|
||||||
.unwrap_or("".into());
|
.unwrap_or("".into());
|
||||||
|
|
||||||
let path = format!(
|
let path = format!(
|
||||||
"/@warpgate#/login?next={}",
|
"/@warpgate?next={}",
|
||||||
utf8_percent_encode(&path, NON_ALPHANUMERIC),
|
utf8_percent_encode(&path, NON_ALPHANUMERIC),
|
||||||
);
|
);
|
||||||
|
|
||||||
Redirect::temporary(path).into_response()
|
Redirect::temporary(path).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_auth_state_for_request(
|
pub async fn get_auth_state_for_request<'a>(
|
||||||
username: &str,
|
username: &str,
|
||||||
session: &Session,
|
session: &Session,
|
||||||
store: &mut AuthStateStore,
|
store: &'a mut AuthStateStore,
|
||||||
) -> Result<Arc<Mutex<AuthState>>, WarpgateError> {
|
) -> Result<&'a mut AuthState, WarpgateError> {
|
||||||
match session.get_auth_state_id() {
|
match session.get_auth_state_id() {
|
||||||
Some(id) => {
|
Some(id) => {
|
||||||
if !store.contains_key(&id.0) {
|
if !store.contains_key(&id.0) {
|
||||||
|
@ -192,7 +192,7 @@ pub async fn get_auth_state_for_request(
|
||||||
};
|
};
|
||||||
|
|
||||||
match session.get_auth_state_id() {
|
match session.get_auth_state_id() {
|
||||||
Some(id) => Ok(store.get(&id.0).unwrap()),
|
Some(id) => Ok(store.get_mut(&id.0).unwrap()),
|
||||||
None => {
|
None => {
|
||||||
let (id, state) = store
|
let (id, state) = store
|
||||||
.create(&username, crate::common::PROTOCOL_NAME)
|
.create(&username, crate::common::PROTOCOL_NAME)
|
||||||
|
|
|
@ -7,7 +7,7 @@ use tokio::net::TcpStream;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tracing::*;
|
use tracing::*;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use warpgate_common::auth::{AuthCredential, AuthSelector};
|
use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState};
|
||||||
use warpgate_common::helpers::rng::get_crypto_rng;
|
use warpgate_common::helpers::rng::get_crypto_rng;
|
||||||
use warpgate_common::{
|
use warpgate_common::{
|
||||||
authorize_ticket, AuthResult, Secret, Services, TargetMySqlOptions, TargetOptions,
|
authorize_ticket, AuthResult, Secret, Services, TargetMySqlOptions, TargetOptions,
|
||||||
|
@ -180,20 +180,15 @@ impl MySqlSession {
|
||||||
username,
|
username,
|
||||||
target_name,
|
target_name,
|
||||||
} => {
|
} => {
|
||||||
let state_arc = self
|
|
||||||
.services
|
|
||||||
.auth_state_store
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.create(&username, crate::common::PROTOCOL_NAME)
|
|
||||||
.await?
|
|
||||||
.1;
|
|
||||||
let mut state = state_arc.lock().await;
|
|
||||||
|
|
||||||
let user_auth_result = {
|
let user_auth_result = {
|
||||||
let credential = AuthCredential::Password(password);
|
|
||||||
|
|
||||||
let mut cp = self.services.config_provider.lock().await;
|
let mut cp = self.services.config_provider.lock().await;
|
||||||
|
|
||||||
|
let credential = AuthCredential::Password(password);
|
||||||
|
let mut state = AuthState::new(
|
||||||
|
username.clone(),
|
||||||
|
crate::common::PROTOCOL_NAME.to_string(),
|
||||||
|
cp.get_credential_policy(&username).await?,
|
||||||
|
);
|
||||||
if cp.validate_credential(&username, &credential).await? {
|
if cp.validate_credential(&username, &credential).await? {
|
||||||
state.add_valid_credential(credential);
|
state.add_valid_credential(credential);
|
||||||
}
|
}
|
||||||
|
@ -203,12 +198,6 @@ impl MySqlSession {
|
||||||
|
|
||||||
match user_auth_result {
|
match user_auth_result {
|
||||||
AuthResult::Accepted { username } => {
|
AuthResult::Accepted { username } => {
|
||||||
self.services
|
|
||||||
.auth_state_store
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.complete(state.id())
|
|
||||||
.await;
|
|
||||||
let target_auth_result = {
|
let target_auth_result = {
|
||||||
self.services
|
self.services
|
||||||
.config_provider
|
.config_provider
|
||||||
|
@ -227,7 +216,9 @@ impl MySqlSession {
|
||||||
}
|
}
|
||||||
self.run_authorized(handshake, username, target_name).await
|
self.run_authorized(handshake, username, target_name).await
|
||||||
}
|
}
|
||||||
AuthResult::Rejected | AuthResult::Need(_) => fail(&mut self).await, // TODO SSO
|
AuthResult::Rejected
|
||||||
|
| AuthResult::Need(_)
|
||||||
|
| AuthResult::NeedMoreCredentials => fail(&mut self).await, // TODO SSO
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AuthSelector::Ticket { secret } => {
|
AuthSelector::Ticket { secret } => {
|
||||||
|
|
|
@ -12,7 +12,7 @@ bimap = "0.6"
|
||||||
bytes = "1.2"
|
bytes = "1.2"
|
||||||
dialoguer = "0.10"
|
dialoguer = "0.10"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
russh = {version = "0.34.0-beta.8", features = ["openssl"]}
|
russh = {version = "0.34.0-beta.7", features = ["openssl"]}
|
||||||
russh-keys = {version = "0.22.0-beta.4", features = ["openssl"]}
|
russh-keys = {version = "0.22.0-beta.4", features = ["openssl"]}
|
||||||
sea-orm = {version = "^0.9", features = ["runtime-tokio-native-tls"], default-features = false}
|
sea-orm = {version = "^0.9", features = ["runtime-tokio-native-tls"], default-features = false}
|
||||||
thiserror = "1.0"
|
thiserror = "1.0"
|
||||||
|
|
|
@ -23,7 +23,6 @@ pub async fn run_server(services: Services, address: SocketAddr) -> Result<()> {
|
||||||
let config = services.config.lock().await;
|
let config = services.config.lock().await;
|
||||||
russh::server::Config {
|
russh::server::Config {
|
||||||
auth_rejection_time: std::time::Duration::from_secs(1),
|
auth_rejection_time: std::time::Duration::from_secs(1),
|
||||||
connection_timeout: Some(std::time::Duration::from_secs(300)),
|
|
||||||
methods: MethodSet::PUBLICKEY | MethodSet::PASSWORD | MethodSet::KEYBOARD_INTERACTIVE,
|
methods: MethodSet::PUBLICKEY | MethodSet::PASSWORD | MethodSet::KEYBOARD_INTERACTIVE,
|
||||||
keys: load_host_keys(&config)?,
|
keys: load_host_keys(&config)?,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
use std::collections::hash_map::Entry::Vacant;
|
use std::collections::hash_map::Entry::Vacant;
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::HashMap;
|
||||||
use std::net::{Ipv4Addr, SocketAddr};
|
use std::net::{Ipv4Addr, SocketAddr};
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -10,11 +10,11 @@ use anyhow::{Context, Result};
|
||||||
use bimap::BiMap;
|
use bimap::BiMap;
|
||||||
use bytes::{Bytes, BytesMut};
|
use bytes::{Bytes, BytesMut};
|
||||||
use russh::server::Session;
|
use russh::server::Session;
|
||||||
use russh::{CryptoVec, MethodSet, Sig};
|
use russh::{CryptoVec, Sig};
|
||||||
use russh_keys::key::PublicKey;
|
use russh_keys::key::PublicKey;
|
||||||
use russh_keys::PublicKeyBase64;
|
use russh_keys::PublicKeyBase64;
|
||||||
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
|
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
|
||||||
use tokio::sync::{broadcast, oneshot, Mutex};
|
use tokio::sync::{oneshot, Mutex};
|
||||||
use tracing::*;
|
use tracing::*;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState, CredentialKind};
|
use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState, CredentialKind};
|
||||||
|
@ -53,12 +53,6 @@ enum Event {
|
||||||
Client(RCEvent),
|
Client(RCEvent),
|
||||||
}
|
}
|
||||||
|
|
||||||
enum KeyboardInteractiveState {
|
|
||||||
None,
|
|
||||||
OtpRequested,
|
|
||||||
WebAuthRequested(broadcast::Receiver<AuthResult>),
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct ServerSession {
|
pub struct ServerSession {
|
||||||
pub id: SessionId,
|
pub id: SessionId,
|
||||||
username: Option<String>,
|
username: Option<String>,
|
||||||
|
@ -80,8 +74,7 @@ pub struct ServerSession {
|
||||||
hub: EventHub<Event>,
|
hub: EventHub<Event>,
|
||||||
event_sender: EventSender<Event>,
|
event_sender: EventSender<Event>,
|
||||||
service_output: ServiceOutput,
|
service_output: ServiceOutput,
|
||||||
auth_state: Option<Arc<Mutex<AuthState>>>,
|
auth_state: Option<AuthState>,
|
||||||
keyboard_interactive_state: KeyboardInteractiveState,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn session_debug_tag(id: &SessionId, remote_address: &SocketAddr) -> String {
|
fn session_debug_tag(id: &SessionId, remote_address: &SocketAddr) -> String {
|
||||||
|
@ -149,7 +142,6 @@ impl ServerSession {
|
||||||
so_tx.send(BytesMut::from(data).freeze()).context("x")
|
so_tx.send(BytesMut::from(data).freeze()).context("x")
|
||||||
})),
|
})),
|
||||||
auth_state: None,
|
auth_state: None,
|
||||||
keyboard_interactive_state: KeyboardInteractiveState::None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let this = Arc::new(Mutex::new(this));
|
let this = Arc::new(Mutex::new(this));
|
||||||
|
@ -225,23 +217,18 @@ impl ServerSession {
|
||||||
Ok(this)
|
Ok(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_auth_state(&mut self, username: &str) -> Result<Arc<Mutex<AuthState>>> {
|
async fn get_auth_state(&mut self, username: &str) -> Result<&mut AuthState> {
|
||||||
#[allow(clippy::unwrap_used)]
|
#[allow(clippy::unwrap_used)]
|
||||||
if self.auth_state.is_none()
|
if self.auth_state.is_none() || self.auth_state.as_ref().unwrap().username() != username {
|
||||||
|| self.auth_state.as_ref().unwrap().lock().await.username() != username
|
let mut cp = self.services.config_provider.lock().await;
|
||||||
{
|
self.auth_state = Some(AuthState::new(
|
||||||
let state = self
|
username.to_string(),
|
||||||
.services
|
crate::PROTOCOL_NAME.to_string(),
|
||||||
.auth_state_store
|
cp.get_credential_policy(username).await?,
|
||||||
.lock()
|
));
|
||||||
.await
|
|
||||||
.create(username, crate::PROTOCOL_NAME)
|
|
||||||
.await?
|
|
||||||
.1;
|
|
||||||
self.auth_state = Some(state);
|
|
||||||
}
|
}
|
||||||
#[allow(clippy::unwrap_used)]
|
#[allow(clippy::unwrap_used)]
|
||||||
Ok(self.auth_state.as_ref().map(Clone::clone).unwrap())
|
Ok(self.auth_state.as_mut().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn make_logging_span(&self) -> tracing::Span {
|
pub fn make_logging_span(&self) -> tracing::Span {
|
||||||
|
@ -952,17 +939,13 @@ impl ServerSession {
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
||||||
Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
|
Ok(AuthResult::Rejected) => russh::server::Auth::Reject,
|
||||||
proceed_with_methods: Some(MethodSet::all()),
|
Ok(AuthResult::Need(_) | AuthResult::NeedMoreCredentials) => {
|
||||||
},
|
russh::server::Auth::Reject
|
||||||
Ok(AuthResult::Need(kinds)) => russh::server::Auth::Reject {
|
}
|
||||||
proceed_with_methods: Some(self.get_remaining_auth_methods(kinds)),
|
|
||||||
},
|
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
error!(?error, "Failed to verify credentials");
|
error!(?error, "Failed to verify credentials");
|
||||||
russh::server::Auth::Reject {
|
russh::server::Auth::Reject
|
||||||
proceed_with_methods: None,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -980,17 +963,13 @@ impl ServerSession {
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
||||||
Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
|
Ok(AuthResult::Rejected) => russh::server::Auth::Reject,
|
||||||
proceed_with_methods: None,
|
Ok(AuthResult::Need(_) | AuthResult::NeedMoreCredentials) => {
|
||||||
},
|
russh::server::Auth::Reject
|
||||||
Ok(AuthResult::Need(_)) => russh::server::Auth::Reject {
|
}
|
||||||
proceed_with_methods: None,
|
|
||||||
},
|
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
error!(?error, "Failed to verify credentials");
|
error!(?error, "Failed to verify credentials");
|
||||||
russh::server::Auth::Reject {
|
russh::server::Auth::Reject
|
||||||
proceed_with_methods: None,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1003,110 +982,26 @@ impl ServerSession {
|
||||||
let selector: AuthSelector = ssh_username.expose_secret().into();
|
let selector: AuthSelector = ssh_username.expose_secret().into();
|
||||||
info!("Keyboard-interactive auth as {:?}", selector);
|
info!("Keyboard-interactive auth as {:?}", selector);
|
||||||
|
|
||||||
let cred;
|
let cred = response.map(AuthCredential::Otp);
|
||||||
match &mut self.keyboard_interactive_state {
|
|
||||||
KeyboardInteractiveState::None => {
|
|
||||||
cred = None;
|
|
||||||
}
|
|
||||||
KeyboardInteractiveState::OtpRequested => {
|
|
||||||
cred = response.map(AuthCredential::Otp);
|
|
||||||
}
|
|
||||||
KeyboardInteractiveState::WebAuthRequested(event) => {
|
|
||||||
cred = None;
|
|
||||||
let _ = event.recv().await;
|
|
||||||
// the auth state has been updated by now
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
self.keyboard_interactive_state = KeyboardInteractiveState::None;
|
|
||||||
|
|
||||||
match self.try_auth(&selector, cred).await {
|
match self.try_auth(&selector, cred).await {
|
||||||
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
||||||
Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
|
Ok(
|
||||||
proceed_with_methods: None,
|
AuthResult::Rejected
|
||||||
},
|
| AuthResult::NeedMoreCredentials
|
||||||
Ok(AuthResult::Need(kinds)) => {
|
| AuthResult::Need(CredentialKind::Otp),
|
||||||
if kinds.contains(&CredentialKind::Otp) {
|
) => russh::server::Auth::Partial {
|
||||||
self.keyboard_interactive_state = KeyboardInteractiveState::OtpRequested;
|
|
||||||
russh::server::Auth::Partial {
|
|
||||||
name: Cow::Borrowed("Two-factor authentication"),
|
name: Cow::Borrowed("Two-factor authentication"),
|
||||||
instructions: Cow::Borrowed(""),
|
instructions: Cow::Borrowed(""),
|
||||||
prompts: Cow::Owned(vec![(Cow::Borrowed("One-time password: "), true)]),
|
prompts: Cow::Owned(vec![(Cow::Borrowed("One-time password: "), true)]),
|
||||||
}
|
},
|
||||||
} else if kinds.contains(&CredentialKind::WebUserApproval) {
|
Ok(AuthResult::Need(_)) => russh::server::Auth::Reject, // TODO SSO
|
||||||
let Some(auth_state) = self.auth_state.as_ref() else {
|
|
||||||
return russh::server::Auth::Reject { proceed_with_methods: None};
|
|
||||||
};
|
|
||||||
let auth_state_id = *auth_state.lock().await.id();
|
|
||||||
let event = self
|
|
||||||
.services
|
|
||||||
.auth_state_store
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.subscribe(auth_state_id);
|
|
||||||
self.keyboard_interactive_state =
|
|
||||||
KeyboardInteractiveState::WebAuthRequested(event);
|
|
||||||
|
|
||||||
let mut login_url = match self
|
|
||||||
.services
|
|
||||||
.config
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.construct_external_url(None)
|
|
||||||
{
|
|
||||||
Ok(url) => url,
|
|
||||||
Err(error) => {
|
|
||||||
error!(?error, "Failed to construct external URL");
|
|
||||||
return russh::server::Auth::Reject {
|
|
||||||
proceed_with_methods: None,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
login_url.set_path("@warpgate");
|
|
||||||
login_url
|
|
||||||
.set_fragment(Some(&format!("/login?next=%2Flogin%2F{auth_state_id}")));
|
|
||||||
|
|
||||||
russh::server::Auth::Partial {
|
|
||||||
name: Cow::Owned(format!(
|
|
||||||
concat!(
|
|
||||||
"----------------------------------------------------------------\n",
|
|
||||||
"Warpgate authentication: please open {} in your browser\n",
|
|
||||||
"----------------------------------------------------------------\n"
|
|
||||||
),
|
|
||||||
login_url
|
|
||||||
)),
|
|
||||||
instructions: Cow::Borrowed(""),
|
|
||||||
prompts: Cow::Owned(vec![(Cow::Borrowed("Press Enter when done: "), true)]),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
russh::server::Auth::Reject {
|
|
||||||
proceed_with_methods: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
error!(?error, "Failed to verify credentials");
|
error!(?error, "Failed to verify credentials");
|
||||||
russh::server::Auth::Reject {
|
russh::server::Auth::Reject
|
||||||
proceed_with_methods: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
fn get_remaining_auth_methods(&self, kinds: HashSet<CredentialKind>) -> MethodSet {
|
|
||||||
let mut m = MethodSet::empty();
|
|
||||||
for kind in kinds {
|
|
||||||
match kind {
|
|
||||||
CredentialKind::Password => m.insert(MethodSet::PASSWORD),
|
|
||||||
CredentialKind::Otp => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
|
|
||||||
CredentialKind::WebUserApproval => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
|
|
||||||
CredentialKind::PublicKey => m.insert(MethodSet::PUBLICKEY),
|
|
||||||
CredentialKind::Sso => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn try_auth(
|
async fn try_auth(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
@ -1119,9 +1014,7 @@ impl ServerSession {
|
||||||
target_name,
|
target_name,
|
||||||
} => {
|
} => {
|
||||||
let cp = self.services.config_provider.clone();
|
let cp = self.services.config_provider.clone();
|
||||||
|
let state = self.get_auth_state(username).await?;
|
||||||
let state_arc = self.get_auth_state(username).await?;
|
|
||||||
let mut state = state_arc.lock().await;
|
|
||||||
|
|
||||||
if let Some(credential) = credential {
|
if let Some(credential) = credential {
|
||||||
if cp
|
if cp
|
||||||
|
@ -1138,12 +1031,6 @@ impl ServerSession {
|
||||||
|
|
||||||
match user_auth_result {
|
match user_auth_result {
|
||||||
AuthResult::Accepted { username } => {
|
AuthResult::Accepted { username } => {
|
||||||
self.services
|
|
||||||
.auth_state_store
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.complete(state.id())
|
|
||||||
.await;
|
|
||||||
let target_auth_result = {
|
let target_auth_result = {
|
||||||
self.services
|
self.services
|
||||||
.config_provider
|
.config_provider
|
||||||
|
|
|
@ -2,25 +2,22 @@
|
||||||
import { faSignOut } from '@fortawesome/free-solid-svg-icons'
|
import { faSignOut } from '@fortawesome/free-solid-svg-icons'
|
||||||
import { Alert, Spinner } from 'sveltestrap'
|
import { Alert, Spinner } from 'sveltestrap'
|
||||||
import Fa from 'svelte-fa'
|
import Fa from 'svelte-fa'
|
||||||
import Router, { push } from 'svelte-spa-router'
|
|
||||||
import { wrap } from 'svelte-spa-router/wrap'
|
|
||||||
import { get } from 'svelte/store'
|
|
||||||
import { api } from 'gateway/lib/api'
|
import { api } from 'gateway/lib/api'
|
||||||
import { reloadServerInfo, serverInfo } from 'gateway/lib/store'
|
import { reloadServerInfo, serverInfo } from 'gateway/lib/store'
|
||||||
import ThemeSwitcher from 'common/ThemeSwitcher.svelte'
|
import ThemeSwitcher from 'common/ThemeSwitcher.svelte'
|
||||||
|
import Login from './Login.svelte'
|
||||||
|
import TargetList from './TargetList.svelte'
|
||||||
import Logo from 'common/Logo.svelte'
|
import Logo from 'common/Logo.svelte'
|
||||||
|
|
||||||
let redirecting = false
|
let redirecting = false
|
||||||
let serverInfoPromise = reloadServerInfo()
|
|
||||||
|
|
||||||
async function init () {
|
async function init () {
|
||||||
await serverInfoPromise
|
await reloadServerInfo()
|
||||||
}
|
}
|
||||||
|
|
||||||
async function logout () {
|
async function logout () {
|
||||||
await api.logout()
|
await api.logout()
|
||||||
await reloadServerInfo()
|
await reloadServerInfo()
|
||||||
push('/login')
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function onPageResume () {
|
function onPageResume () {
|
||||||
|
@ -28,36 +25,6 @@ function onPageResume () {
|
||||||
init()
|
init()
|
||||||
}
|
}
|
||||||
|
|
||||||
async function requireLogin (detail) {
|
|
||||||
await serverInfoPromise
|
|
||||||
if (!get(serverInfo)?.username) {
|
|
||||||
let url = detail.location
|
|
||||||
if (detail.querystring) {
|
|
||||||
url += '?' + detail.querystring
|
|
||||||
}
|
|
||||||
push('/login?next=' + encodeURIComponent(url))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
const routes = {
|
|
||||||
'/': wrap({
|
|
||||||
asyncComponent: () => import('./TargetList.svelte'),
|
|
||||||
props: {
|
|
||||||
'on:navigation': () => redirecting = true,
|
|
||||||
},
|
|
||||||
conditions: [requireLogin],
|
|
||||||
}),
|
|
||||||
'/login': wrap({
|
|
||||||
asyncComponent: () => import('./Login.svelte'),
|
|
||||||
}),
|
|
||||||
'/login/:stateId': wrap({
|
|
||||||
asyncComponent: () => import('./OutOfBandAuth.svelte'),
|
|
||||||
conditions: [requireLogin],
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
|
|
||||||
init()
|
init()
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
|
@ -71,9 +38,9 @@ init()
|
||||||
<Spinner />
|
<Spinner />
|
||||||
{:else}
|
{:else}
|
||||||
<div class="d-flex align-items-center mt-5 mb-5">
|
<div class="d-flex align-items-center mt-5 mb-5">
|
||||||
<a class="logo" href="/@warpgate">
|
<div class="logo">
|
||||||
<Logo />
|
<Logo />
|
||||||
</a>
|
</div>
|
||||||
|
|
||||||
{#if $serverInfo?.username}
|
{#if $serverInfo?.username}
|
||||||
<div class="ms-auto">
|
<div class="ms-auto">
|
||||||
|
@ -89,7 +56,12 @@ init()
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<main>
|
<main>
|
||||||
<Router {routes}/>
|
{#if $serverInfo?.username}
|
||||||
|
<TargetList
|
||||||
|
on:navigation={() => redirecting = true} />
|
||||||
|
{:else}
|
||||||
|
<Login />
|
||||||
|
{/if}
|
||||||
</main>
|
</main>
|
||||||
|
|
||||||
<footer class="mt-5">
|
<footer class="mt-5">
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
<script lang="ts">
|
<script lang="ts">
|
||||||
import { get } from 'svelte/store'
|
import { replace } from 'svelte-spa-router'
|
||||||
import { querystring, replace } from 'svelte-spa-router'
|
|
||||||
import { Alert, FormGroup, Spinner } from 'sveltestrap'
|
import { Alert, FormGroup, Spinner } from 'sveltestrap'
|
||||||
import Fa from 'svelte-fa'
|
import Fa from 'svelte-fa'
|
||||||
import { faArrowRight } from '@fortawesome/free-solid-svg-icons'
|
import { faArrowRight } from '@fortawesome/free-solid-svg-icons'
|
||||||
|
@ -10,47 +9,25 @@ import { api, ApiAuthState, LoginFailureResponseFromJSON, SsoProviderDescription
|
||||||
import { reloadServerInfo } from 'gateway/lib/store'
|
import { reloadServerInfo } from 'gateway/lib/store'
|
||||||
import AsyncButton from 'common/AsyncButton.svelte'
|
import AsyncButton from 'common/AsyncButton.svelte'
|
||||||
|
|
||||||
export let params: { stateId?: string } = {}
|
|
||||||
|
|
||||||
let error: Error|null = null
|
let error: Error|null = null
|
||||||
let username = ''
|
let username = ''
|
||||||
let password = ''
|
let password = ''
|
||||||
let otp = ''
|
let otp = ''
|
||||||
let busy = false
|
let busy = false
|
||||||
|
|
||||||
let authState: ApiAuthState|undefined = undefined
|
let authState = ApiAuthState.NotStarted
|
||||||
|
|
||||||
let ssoProvidersPromise = api.getSsoProviders()
|
let ssoProvidersPromise = api.getSsoProviders()
|
||||||
|
|
||||||
const nextURL = new URLSearchParams(get(querystring)).get('next') ?? undefined
|
const nextURL = new URLSearchParams(location.search).get('next')
|
||||||
const serverErrorMessage = new URLSearchParams(location.search).get('login_error')
|
const serverErrorMessage = new URLSearchParams(location.search).get('login_error')
|
||||||
|
|
||||||
async function init () {
|
async function init () {
|
||||||
try {
|
authState = (await api.getAuthState()).state
|
||||||
authState = (await api.getDefaultAuthState()).state
|
|
||||||
} catch (err) {
|
|
||||||
if (err.status) {
|
|
||||||
const response = err as Response
|
|
||||||
if (response.status === 404) {
|
|
||||||
authState = ApiAuthState.NotStarted
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continueWithState()
|
continueWithState()
|
||||||
}
|
}
|
||||||
|
|
||||||
function success () {
|
|
||||||
if (nextURL) {
|
|
||||||
replace(nextURL)
|
|
||||||
} else {
|
|
||||||
replace('/')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function continueWithState () {
|
async function continueWithState () {
|
||||||
if (authState === ApiAuthState.Success) {
|
|
||||||
success()
|
|
||||||
}
|
|
||||||
if (authState === ApiAuthState.SsoNeeded) {
|
if (authState === ApiAuthState.SsoNeeded) {
|
||||||
const providers = await ssoProvidersPromise
|
const providers = await ssoProvidersPromise
|
||||||
if (!providers.length) {
|
if (!providers.length) {
|
||||||
|
@ -88,15 +65,18 @@ async function _login () {
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
if (nextURL) {
|
||||||
|
location.href = nextURL
|
||||||
|
} else {
|
||||||
await reloadServerInfo()
|
await reloadServerInfo()
|
||||||
success()
|
replace('/')
|
||||||
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err.status) {
|
if (err.status) {
|
||||||
const response = err as Response
|
const response = err as Response
|
||||||
if (response.status === 401) {
|
if (response.status === 401) {
|
||||||
const failure = LoginFailureResponseFromJSON(await response.json())
|
const failure = LoginFailureResponseFromJSON(await response.json())
|
||||||
authState = failure.state
|
authState = failure.state
|
||||||
|
|
||||||
continueWithState()
|
continueWithState()
|
||||||
} else {
|
} else {
|
||||||
error = new Error(await response.text())
|
error = new Error(await response.text())
|
||||||
|
@ -116,8 +96,8 @@ function onInputKey (event: KeyboardEvent) {
|
||||||
async function startSSO (provider: SsoProviderDescription) {
|
async function startSSO (provider: SsoProviderDescription) {
|
||||||
busy = true
|
busy = true
|
||||||
try {
|
try {
|
||||||
const p = await api.startSso({ name: provider.name, next: nextURL })
|
const params = await api.startSso(provider)
|
||||||
location.href = p.url
|
location.href = params.url
|
||||||
} catch {
|
} catch {
|
||||||
busy = false
|
busy = false
|
||||||
}
|
}
|
||||||
|
@ -135,10 +115,6 @@ async function startSSO (provider: SsoProviderDescription) {
|
||||||
<h1>Continue login</h1>
|
<h1>Continue login</h1>
|
||||||
{/if}
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
{#if params.stateId}
|
|
||||||
//todo
|
|
||||||
loggin in for auth state id {params.stateId}
|
|
||||||
{/if}
|
|
||||||
{#if authState === ApiAuthState.OtpNeeded}
|
{#if authState === ApiAuthState.OtpNeeded}
|
||||||
<FormGroup floating label="One-time password">
|
<FormGroup floating label="One-time password">
|
||||||
<!-- svelte-ignore a11y-autofocus -->
|
<!-- svelte-ignore a11y-autofocus -->
|
||||||
|
|
|
@ -1,65 +0,0 @@
|
||||||
<script lang="ts">
|
|
||||||
import { Alert, Spinner } from 'sveltestrap'
|
|
||||||
|
|
||||||
import { api, ApiAuthState, AuthStateResponseInternal } from 'gateway/lib/api'
|
|
||||||
import AsyncButton from 'common/AsyncButton.svelte'
|
|
||||||
|
|
||||||
export let params: { stateId: string }
|
|
||||||
let authState: AuthStateResponseInternal
|
|
||||||
|
|
||||||
async function reload () {
|
|
||||||
authState = await api.getAuthState({ id: params.stateId })
|
|
||||||
}
|
|
||||||
|
|
||||||
async function init () {
|
|
||||||
await reload()
|
|
||||||
}
|
|
||||||
|
|
||||||
async function approve () {
|
|
||||||
api.approveAuth({ id: params.stateId })
|
|
||||||
await reload()
|
|
||||||
}
|
|
||||||
|
|
||||||
async function reject () {
|
|
||||||
api.rejectAuth({ id: params.stateId })
|
|
||||||
await reload()
|
|
||||||
}
|
|
||||||
</script>
|
|
||||||
|
|
||||||
{#await init()}
|
|
||||||
<Spinner />
|
|
||||||
{:then}
|
|
||||||
<div class="page-summary-bar">
|
|
||||||
<h1>Authorization request</h1>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<p>Authorize this {authState.protocol} session?</p>
|
|
||||||
|
|
||||||
{#if authState.state === ApiAuthState.Success}
|
|
||||||
<Alert color="success">
|
|
||||||
Approved
|
|
||||||
</Alert>
|
|
||||||
{:else if authState.state === ApiAuthState.Failed}
|
|
||||||
<Alert color="danger">
|
|
||||||
Rejected
|
|
||||||
</Alert>
|
|
||||||
{:else}
|
|
||||||
<div class="d-flex">
|
|
||||||
<AsyncButton
|
|
||||||
color="primary"
|
|
||||||
class="d-flex align-items-center ms-auto"
|
|
||||||
click={approve}
|
|
||||||
>
|
|
||||||
Authorize
|
|
||||||
</AsyncButton>
|
|
||||||
<AsyncButton
|
|
||||||
outline
|
|
||||||
color="secondary"
|
|
||||||
class="d-flex align-items-center ms-2"
|
|
||||||
click={reject}
|
|
||||||
>
|
|
||||||
Reject
|
|
||||||
</AsyncButton>
|
|
||||||
</div>
|
|
||||||
{/if}
|
|
||||||
{/await}
|
|
|
@ -71,16 +71,6 @@
|
||||||
"operationId": "otpLogin"
|
"operationId": "otpLogin"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/auth/logout": {
|
|
||||||
"post": {
|
|
||||||
"responses": {
|
|
||||||
"201": {
|
|
||||||
"description": ""
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"operationId": "logout"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/auth/state": {
|
"/auth/state": {
|
||||||
"get": {
|
"get": {
|
||||||
"responses": {
|
"responses": {
|
||||||
|
@ -93,108 +83,19 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
|
||||||
"404": {
|
|
||||||
"description": ""
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"operationId": "getDefaultAuthState"
|
"operationId": "getAuthState"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/auth/state/{id}": {
|
"/auth/logout": {
|
||||||
"get": {
|
|
||||||
"parameters": [
|
|
||||||
{
|
|
||||||
"name": "id",
|
|
||||||
"schema": {
|
|
||||||
"type": "string",
|
|
||||||
"format": "uuid"
|
|
||||||
},
|
|
||||||
"in": "path",
|
|
||||||
"required": true,
|
|
||||||
"deprecated": false
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"$ref": "#/components/schemas/AuthStateResponseInternal"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"404": {
|
|
||||||
"description": ""
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"operationId": "get_auth_state"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/auth/state/{id}/approve": {
|
|
||||||
"post": {
|
"post": {
|
||||||
"parameters": [
|
|
||||||
{
|
|
||||||
"name": "id",
|
|
||||||
"schema": {
|
|
||||||
"type": "string",
|
|
||||||
"format": "uuid"
|
|
||||||
},
|
|
||||||
"in": "path",
|
|
||||||
"required": true,
|
|
||||||
"deprecated": false
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"responses": {
|
"responses": {
|
||||||
"200": {
|
"201": {
|
||||||
"description": "",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"$ref": "#/components/schemas/AuthStateResponseInternal"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"404": {
|
|
||||||
"description": ""
|
"description": ""
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"operationId": "approve_auth"
|
"operationId": "logout"
|
||||||
}
|
|
||||||
},
|
|
||||||
"/auth/state/{id}/reject": {
|
|
||||||
"post": {
|
|
||||||
"parameters": [
|
|
||||||
{
|
|
||||||
"name": "id",
|
|
||||||
"schema": {
|
|
||||||
"type": "string",
|
|
||||||
"format": "uuid"
|
|
||||||
},
|
|
||||||
"in": "path",
|
|
||||||
"required": true,
|
|
||||||
"deprecated": false
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"$ref": "#/components/schemas/AuthStateResponseInternal"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"404": {
|
|
||||||
"description": ""
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"operationId": "reject_auth"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/info": {
|
"/info": {
|
||||||
|
@ -286,15 +187,6 @@
|
||||||
"in": "path",
|
"in": "path",
|
||||||
"required": true,
|
"required": true,
|
||||||
"deprecated": false
|
"deprecated": false
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "next",
|
|
||||||
"schema": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"in": "query",
|
|
||||||
"required": false,
|
|
||||||
"deprecated": false
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"responses": {
|
"responses": {
|
||||||
|
@ -326,21 +218,15 @@
|
||||||
"PasswordNeeded",
|
"PasswordNeeded",
|
||||||
"OtpNeeded",
|
"OtpNeeded",
|
||||||
"SsoNeeded",
|
"SsoNeeded",
|
||||||
"WebUserApprovalNeeded",
|
|
||||||
"PublicKeyNeeded",
|
|
||||||
"Success"
|
"Success"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"AuthStateResponseInternal": {
|
"AuthStateResponseInternal": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
"protocol",
|
|
||||||
"state"
|
"state"
|
||||||
],
|
],
|
||||||
"properties": {
|
"properties": {
|
||||||
"protocol": {
|
|
||||||
"type": "string"
|
|
||||||
},
|
|
||||||
"state": {
|
"state": {
|
||||||
"$ref": "#/components/schemas/ApiAuthState"
|
"$ref": "#/components/schemas/ApiAuthState"
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,9 +30,6 @@ pub(crate) async fn command(cli: &crate::Cli) -> Result<()> {
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Checking MySQL key".to_string())?;
|
.with_context(|| "Checking MySQL key".to_string())?;
|
||||||
}
|
}
|
||||||
if !config.store.sso_providers.is_empty() && config.store.external_host.is_none() {
|
|
||||||
anyhow::bail!("SSO requires the external_host config option");
|
|
||||||
}
|
|
||||||
info!("No problems found");
|
info!("No problems found");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,6 +77,7 @@ fn check_and_migrate_config(store: &mut serde_yaml::Value) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
pub fn watch_config<P: AsRef<Path> + Send + 'static>(
|
pub fn watch_config<P: AsRef<Path> + Send + 'static>(
|
||||||
path: P,
|
path: P,
|
||||||
config: Arc<Mutex<WarpgateConfig>>,
|
config: Arc<Mutex<WarpgateConfig>>,
|
||||||
|
|
Loading…
Reference in a new issue