mirror of
https://github.com/warp-tech/warpgate.git
synced 2024-09-20 06:46:17 +08:00
parent
c6885f18c3
commit
c8a004f756
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
|
@ -12,6 +12,7 @@ jobs:
|
|||
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly
|
||||
target: x86_64-unknown-linux-gnu
|
||||
override: true
|
||||
|
||||
|
@ -41,6 +42,7 @@ jobs:
|
|||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: build
|
||||
toolchain: nightly
|
||||
use-cross: true
|
||||
args: --release --target x86_64-unknown-linux-gnu -Ztarget-applies-to-host
|
||||
|
||||
|
|
6
Cargo.lock
generated
6
Cargo.lock
generated
|
@ -3143,9 +3143,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "russh"
|
||||
version = "0.34.0-beta.8"
|
||||
version = "0.34.0-beta.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ccd8be93ee0b54a8a6b74c77ecef946185f0acbfb5234ea66666887621381e85"
|
||||
checksum = "4f3bb72a66e32d52e0e258627d141d5c93b408e050f15033699caa836d064c7e"
|
||||
dependencies = [
|
||||
"aes 0.8.1",
|
||||
"aes-gcm 0.10.1",
|
||||
|
@ -4620,7 +4620,6 @@ dependencies = [
|
|||
"bytes 1.2.1",
|
||||
"chrono",
|
||||
"data-encoding",
|
||||
"futures",
|
||||
"humantime-serde",
|
||||
"lazy_static",
|
||||
"once_cell",
|
||||
|
@ -4692,7 +4691,6 @@ version = "0.4.0"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"chrono",
|
||||
"cookie",
|
||||
"data-encoding",
|
||||
"delegate",
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
nightly-2022-08-01
|
2
rust-toolchain.toml
Normal file
2
rust-toolchain.toml
Normal file
|
@ -0,0 +1,2 @@
|
|||
[toolchain]
|
||||
channel = "nightly-2022-07-22"
|
|
@ -13,7 +13,6 @@ chrono = { version = "0.4", features = ["serde"] }
|
|||
data-encoding = "2.3"
|
||||
humantime-serde = "1.1"
|
||||
lazy_static = "1.4"
|
||||
futures = "0.3"
|
||||
once_cell = "1.10"
|
||||
packet = "0.1"
|
||||
password-hash = "0.4"
|
||||
|
|
|
@ -13,8 +13,6 @@ pub enum CredentialKind {
|
|||
Otp,
|
||||
#[serde(rename = "sso")]
|
||||
Sso,
|
||||
#[serde(rename = "web")]
|
||||
WebUserApproval,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
|
@ -29,7 +27,6 @@ pub enum AuthCredential {
|
|||
provider: String,
|
||||
email: String,
|
||||
},
|
||||
WebUserApproval,
|
||||
}
|
||||
|
||||
impl AuthCredential {
|
||||
|
@ -39,7 +36,6 @@ impl AuthCredential {
|
|||
Self::PublicKey { .. } => CredentialKind::PublicKey,
|
||||
Self::Otp { .. } => CredentialKind::Otp,
|
||||
Self::Sso { .. } => CredentialKind::Sso,
|
||||
Self::WebUserApproval => CredentialKind::WebUserApproval,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::HashSet;
|
||||
|
||||
use super::{AuthCredential, CredentialKind};
|
||||
use crate::UserRequireCredentialsPolicy;
|
||||
|
||||
pub enum CredentialPolicyResponse {
|
||||
Ok,
|
||||
Need(HashSet<CredentialKind>),
|
||||
NeedMoreCredentials,
|
||||
Need(CredentialKind),
|
||||
}
|
||||
|
||||
pub trait CredentialPolicy {
|
||||
|
@ -15,71 +17,36 @@ pub trait CredentialPolicy {
|
|||
) -> CredentialPolicyResponse;
|
||||
}
|
||||
|
||||
pub struct AnySingleCredentialPolicy {
|
||||
pub supported_credential_types: HashSet<CredentialKind>,
|
||||
}
|
||||
|
||||
pub struct AllCredentialsPolicy {
|
||||
pub required_credential_types: HashSet<CredentialKind>,
|
||||
pub supported_credential_types: HashSet<CredentialKind>,
|
||||
}
|
||||
|
||||
pub struct PerProtocolCredentialPolicy {
|
||||
pub protocols: HashMap<&'static str, Box<dyn CredentialPolicy + Send + Sync>>,
|
||||
pub default: Box<dyn CredentialPolicy + Send + Sync>,
|
||||
}
|
||||
|
||||
impl CredentialPolicy for AnySingleCredentialPolicy {
|
||||
fn is_sufficient(
|
||||
&self,
|
||||
_protocol: &str,
|
||||
valid_credentials: &[AuthCredential],
|
||||
) -> CredentialPolicyResponse {
|
||||
if valid_credentials.is_empty() {
|
||||
CredentialPolicyResponse::Need(
|
||||
self.supported_credential_types
|
||||
.clone()
|
||||
.into_iter()
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
CredentialPolicyResponse::Ok
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialPolicy for AllCredentialsPolicy {
|
||||
fn is_sufficient(
|
||||
&self,
|
||||
_protocol: &str,
|
||||
valid_credentials: &[AuthCredential],
|
||||
) -> CredentialPolicyResponse {
|
||||
let valid_credential_types: HashSet<CredentialKind> =
|
||||
valid_credentials.iter().map(|x| x.kind()).collect();
|
||||
|
||||
if valid_credential_types.is_superset(&self.required_credential_types) {
|
||||
CredentialPolicyResponse::Ok
|
||||
} else {
|
||||
CredentialPolicyResponse::Need(
|
||||
self.required_credential_types
|
||||
.difference(&valid_credential_types)
|
||||
.cloned()
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialPolicy for PerProtocolCredentialPolicy {
|
||||
impl CredentialPolicy for UserRequireCredentialsPolicy {
|
||||
fn is_sufficient(
|
||||
&self,
|
||||
protocol: &str,
|
||||
valid_credentials: &[AuthCredential],
|
||||
) -> CredentialPolicyResponse {
|
||||
if let Some(policy) = self.protocols.get(protocol) {
|
||||
policy.is_sufficient(protocol, valid_credentials)
|
||||
let required_kinds = match protocol {
|
||||
"SSH" => &self.ssh,
|
||||
"HTTP" => &self.http,
|
||||
"MySQL" => &self.mysql,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
if let Some(required_kinds) = required_kinds {
|
||||
let mut remaining_required_kinds = HashSet::<CredentialKind>::new();
|
||||
remaining_required_kinds.extend(required_kinds);
|
||||
for kind in required_kinds {
|
||||
if valid_credentials.iter().any(|x| x.kind() == *kind) {
|
||||
remaining_required_kinds.remove(kind);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(kind) = remaining_required_kinds.into_iter().next() {
|
||||
CredentialPolicyResponse::Need(kind)
|
||||
} else {
|
||||
CredentialPolicyResponse::Ok
|
||||
}
|
||||
} else if valid_credentials.is_empty() {
|
||||
CredentialPolicyResponse::NeedMoreCredentials
|
||||
} else {
|
||||
self.default.is_sufficient(protocol, valid_credentials)
|
||||
CredentialPolicyResponse::Ok
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,66 +1,71 @@
|
|||
use uuid::Uuid;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{AuthCredential, CredentialPolicy, CredentialPolicyResponse};
|
||||
use crate::AuthResult;
|
||||
|
||||
#[allow(clippy::unwrap_used)]
|
||||
pub static TIMEOUT: Lazy<Duration> = Lazy::new(|| Duration::from_secs(60 * 10));
|
||||
|
||||
pub struct AuthState {
|
||||
id: Uuid,
|
||||
username: String,
|
||||
protocol: String,
|
||||
force_rejected: bool,
|
||||
policy: Box<dyn CredentialPolicy + Sync + Send>,
|
||||
policy: Option<Box<dyn CredentialPolicy + Sync + Send>>,
|
||||
valid_credentials: Vec<AuthCredential>,
|
||||
started_at: Instant,
|
||||
}
|
||||
|
||||
impl AuthState {
|
||||
pub(crate) fn new(
|
||||
id: Uuid,
|
||||
pub fn new(
|
||||
username: String,
|
||||
protocol: String,
|
||||
policy: Box<dyn CredentialPolicy + Sync + Send>,
|
||||
policy: Option<Box<dyn CredentialPolicy + Sync + Send>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
username,
|
||||
protocol,
|
||||
force_rejected: false,
|
||||
policy,
|
||||
valid_credentials: vec![],
|
||||
started_at: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &Uuid {
|
||||
&self.id
|
||||
}
|
||||
|
||||
pub fn username(&self) -> &str {
|
||||
&self.username
|
||||
}
|
||||
|
||||
pub fn protocol(&self) -> &str {
|
||||
&self.protocol
|
||||
}
|
||||
|
||||
pub fn add_valid_credential(&mut self, credential: AuthCredential) {
|
||||
self.valid_credentials.push(credential);
|
||||
}
|
||||
|
||||
pub fn reject(&mut self) {
|
||||
self.force_rejected = true;
|
||||
pub fn is_expired(&self) -> bool {
|
||||
self.started_at.elapsed() > *TIMEOUT
|
||||
}
|
||||
|
||||
pub fn verify(&self) -> AuthResult {
|
||||
if self.force_rejected {
|
||||
if self.valid_credentials.is_empty() {
|
||||
warn!(
|
||||
username=%self.username,
|
||||
"No matching valid credentials"
|
||||
);
|
||||
return AuthResult::Rejected;
|
||||
}
|
||||
match self
|
||||
.policy
|
||||
.is_sufficient(&self.protocol, &self.valid_credentials[..])
|
||||
{
|
||||
CredentialPolicyResponse::Ok => AuthResult::Accepted {
|
||||
username: self.username.clone(),
|
||||
},
|
||||
CredentialPolicyResponse::Need(kinds) => AuthResult::Need(kinds),
|
||||
|
||||
if let Some(ref policy) = self.policy {
|
||||
match policy.is_sufficient(&self.protocol, &self.valid_credentials[..]) {
|
||||
CredentialPolicyResponse::Ok => {}
|
||||
CredentialPolicyResponse::Need(kind) => {
|
||||
return AuthResult::Need(kind);
|
||||
}
|
||||
CredentialPolicyResponse::NeedMoreCredentials => {
|
||||
return AuthResult::Rejected;
|
||||
}
|
||||
}
|
||||
}
|
||||
AuthResult::Accepted {
|
||||
username: self.username.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,32 +1,15 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
use tokio::sync::{broadcast, Mutex};
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::AuthState;
|
||||
use crate::{AuthResult, ConfigProvider, WarpgateError};
|
||||
|
||||
#[allow(clippy::unwrap_used)]
|
||||
pub static TIMEOUT: Lazy<Duration> = Lazy::new(|| Duration::from_secs(60 * 10));
|
||||
|
||||
struct AuthCompletionSignal {
|
||||
sender: broadcast::Sender<AuthResult>,
|
||||
created_at: Instant,
|
||||
}
|
||||
|
||||
impl AuthCompletionSignal {
|
||||
pub fn is_expired(&self) -> bool {
|
||||
self.created_at.elapsed() > *TIMEOUT
|
||||
}
|
||||
}
|
||||
use crate::{ConfigProvider, WarpgateError};
|
||||
|
||||
pub struct AuthStateStore {
|
||||
config_provider: Arc<Mutex<dyn ConfigProvider + Send + 'static>>,
|
||||
store: HashMap<Uuid, (Arc<Mutex<AuthState>>, Instant)>,
|
||||
completion_signals: HashMap<Uuid, AuthCompletionSignal>,
|
||||
store: HashMap<Uuid, AuthState>,
|
||||
}
|
||||
|
||||
impl AuthStateStore {
|
||||
|
@ -34,66 +17,47 @@ impl AuthStateStore {
|
|||
Self {
|
||||
store: HashMap::new(),
|
||||
config_provider,
|
||||
completion_signals: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn contains_key(&self, id: &Uuid) -> bool {
|
||||
pub fn contains_key(&mut self, id: &Uuid) -> bool {
|
||||
self.store.contains_key(id)
|
||||
}
|
||||
|
||||
pub fn get(&self, id: &Uuid) -> Option<Arc<Mutex<AuthState>>> {
|
||||
self.store.get(id).map(|x| x.0.clone())
|
||||
pub fn get_mut(&mut self, id: &Uuid) -> Option<&mut AuthState> {
|
||||
self.store.get_mut(id)
|
||||
}
|
||||
|
||||
pub async fn create(
|
||||
&mut self,
|
||||
username: &str,
|
||||
protocol: &str,
|
||||
) -> Result<(Uuid, Arc<Mutex<AuthState>>), WarpgateError> {
|
||||
) -> Result<(Uuid, &mut AuthState), WarpgateError> {
|
||||
let id = Uuid::new_v4();
|
||||
let Some(policy) = self.config_provider
|
||||
.lock()
|
||||
.await
|
||||
.get_credential_policy(username)
|
||||
.await? else {
|
||||
return Err(WarpgateError::UserNotFound)
|
||||
};
|
||||
|
||||
let state = AuthState::new(id, username.to_string(), protocol.to_string(), policy);
|
||||
self.store
|
||||
.insert(id, (Arc::new(Mutex::new(state)), Instant::now()));
|
||||
let state = AuthState::new(
|
||||
username.to_string(),
|
||||
protocol.to_string(),
|
||||
self.config_provider
|
||||
.lock()
|
||||
.await
|
||||
.get_credential_policy(username)
|
||||
.await?,
|
||||
);
|
||||
self.store.insert(id, state);
|
||||
|
||||
#[allow(clippy::unwrap_used)]
|
||||
Ok((id, self.get(&id).unwrap()))
|
||||
}
|
||||
|
||||
pub fn subscribe(&mut self, id: Uuid) -> broadcast::Receiver<AuthResult> {
|
||||
let signal = self.completion_signals.entry(id).or_insert_with(|| {
|
||||
let (sender, _) = broadcast::channel(1);
|
||||
AuthCompletionSignal {
|
||||
sender,
|
||||
created_at: Instant::now(),
|
||||
}
|
||||
});
|
||||
|
||||
signal.sender.subscribe()
|
||||
}
|
||||
|
||||
pub async fn complete(&mut self, id: &Uuid) {
|
||||
let Some((state, _)) = self.store.get(id) else {
|
||||
return
|
||||
};
|
||||
if let Some(sig) = self.completion_signals.remove(id) {
|
||||
let _ = sig.sender.send(state.lock().await.verify());
|
||||
}
|
||||
Ok((id, self.store.get_mut(&id).unwrap()))
|
||||
}
|
||||
|
||||
pub async fn vacuum(&mut self) {
|
||||
self.store
|
||||
.retain(|_, (_, started_at)| started_at.elapsed() < *TIMEOUT);
|
||||
|
||||
self.completion_signals
|
||||
.retain(|_, signal| !signal.is_expired());
|
||||
let mut to_remove = vec![];
|
||||
for (id, state) in self.store.iter() {
|
||||
if state.is_expired() {
|
||||
to_remove.push(*id);
|
||||
}
|
||||
}
|
||||
for id in to_remove {
|
||||
self.store.remove(&id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,12 +5,11 @@ use std::time::Duration;
|
|||
|
||||
use poem_openapi::{Enum, Object, Union};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use url::Url;
|
||||
use warpgate_sso::SsoProviderConfig;
|
||||
|
||||
use crate::auth::CredentialKind;
|
||||
use crate::helpers::otp::OtpSecretKey;
|
||||
use crate::{ListenEndpoint, Secret, WarpgateError};
|
||||
use crate::{ListenEndpoint, Secret};
|
||||
|
||||
const fn _default_true() -> bool {
|
||||
true
|
||||
|
@ -427,24 +426,3 @@ pub struct WarpgateConfig {
|
|||
pub store: WarpgateConfigStore,
|
||||
pub paths_relative_to: PathBuf,
|
||||
}
|
||||
|
||||
impl WarpgateConfig {
|
||||
pub fn construct_external_url(
|
||||
&self,
|
||||
fallback_host: Option<&str>,
|
||||
) -> Result<Url, WarpgateError> {
|
||||
let ext_host = self.store.external_host.as_deref().or(fallback_host);
|
||||
let Some(ext_host) = ext_host else {
|
||||
return Err(WarpgateError::ExternalHostNotSet);
|
||||
};
|
||||
let ext_port = self.store.http.listen.port();
|
||||
|
||||
let mut url = Url::parse(&format!("https://{ext_host}/"))?;
|
||||
|
||||
if ext_port != 443 {
|
||||
let _ = url.set_port(Some(ext_port));
|
||||
}
|
||||
|
||||
Ok(url)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
@ -11,10 +11,7 @@ use uuid::Uuid;
|
|||
use warpgate_db_entities::Ticket;
|
||||
|
||||
use super::ConfigProvider;
|
||||
use crate::auth::{
|
||||
AllCredentialsPolicy, AnySingleCredentialPolicy, AuthCredential, CredentialKind,
|
||||
CredentialPolicy, PerProtocolCredentialPolicy,
|
||||
};
|
||||
use crate::auth::{AuthCredential, CredentialPolicy};
|
||||
use crate::helpers::hash::verify_password_hash;
|
||||
use crate::helpers::otp::verify_totp;
|
||||
use crate::{Target, User, UserAuthCredential, UserSnapshot, WarpgateConfig, WarpgateError};
|
||||
|
@ -81,52 +78,9 @@ impl ConfigProvider for FileConfigProvider {
|
|||
return Ok(None);
|
||||
};
|
||||
|
||||
let supported_credential_types: HashSet<CredentialKind> =
|
||||
user.credentials.iter().map(|x| x.kind()).collect();
|
||||
let default_policy = Box::new(AnySingleCredentialPolicy {
|
||||
supported_credential_types: supported_credential_types.clone(),
|
||||
}) as Box<dyn CredentialPolicy + Sync + Send>;
|
||||
|
||||
if let Some(req) = user.require {
|
||||
let mut policy = PerProtocolCredentialPolicy {
|
||||
default: default_policy,
|
||||
protocols: HashMap::new(),
|
||||
};
|
||||
|
||||
if let Some(p) = req.http {
|
||||
policy.protocols.insert(
|
||||
"HTTP",
|
||||
Box::new(AllCredentialsPolicy {
|
||||
supported_credential_types: supported_credential_types.clone(),
|
||||
required_credential_types: p.into_iter().collect(),
|
||||
}),
|
||||
);
|
||||
}
|
||||
if let Some(p) = req.mysql {
|
||||
policy.protocols.insert(
|
||||
"MySQL",
|
||||
Box::new(AllCredentialsPolicy {
|
||||
supported_credential_types: supported_credential_types.clone(),
|
||||
required_credential_types: p.into_iter().collect(),
|
||||
}),
|
||||
);
|
||||
}
|
||||
if let Some(p) = req.ssh {
|
||||
policy.protocols.insert(
|
||||
"SSH",
|
||||
Box::new(AllCredentialsPolicy {
|
||||
supported_credential_types,
|
||||
required_credential_types: p.into_iter().collect(),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Some(
|
||||
Box::new(policy) as Box<dyn CredentialPolicy + Sync + Send>
|
||||
))
|
||||
} else {
|
||||
Ok(Some(default_policy))
|
||||
}
|
||||
Ok(user
|
||||
.require
|
||||
.map(|r| Box::new(r) as Box<dyn CredentialPolicy + Sync + Send>))
|
||||
}
|
||||
|
||||
async fn username_for_sso_credential(
|
||||
|
@ -239,7 +193,6 @@ impl ConfigProvider for FileConfigProvider {
|
|||
}
|
||||
return Ok(false);
|
||||
}
|
||||
_ => return Err(WarpgateError::InvalidCredentialType),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
mod file;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
@ -13,10 +12,11 @@ use warpgate_db_entities::Ticket;
|
|||
use crate::auth::{AuthCredential, CredentialKind, CredentialPolicy};
|
||||
use crate::{Secret, Target, UserSnapshot, WarpgateError};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug)]
|
||||
pub enum AuthResult {
|
||||
Accepted { username: String },
|
||||
Need(HashSet<CredentialKind>),
|
||||
Need(CredentialKind),
|
||||
NeedMoreCredentials,
|
||||
Rejected,
|
||||
}
|
||||
|
||||
|
|
|
@ -9,16 +9,8 @@ pub enum WarpgateError {
|
|||
DatabaseError(#[from] sea_orm::DbErr),
|
||||
#[error("ticket not found: {0}")]
|
||||
InvalidTicket(Uuid),
|
||||
#[error("invalid credential type")]
|
||||
InvalidCredentialType,
|
||||
#[error(transparent)]
|
||||
Other(Box<dyn Error + Send + Sync>),
|
||||
#[error("user not found")]
|
||||
UserNotFound,
|
||||
#[error("failed to parse url: {0}")]
|
||||
UrlParse(#[from] url::ParseError),
|
||||
#[error("external_url config option is not set")]
|
||||
ExternalHostNotSet,
|
||||
}
|
||||
|
||||
impl ResponseError for WarpgateError {
|
||||
|
|
|
@ -7,7 +7,6 @@ version = "0.4.0"
|
|||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
async-trait = "0.1"
|
||||
chrono = {version = "0.4", features = ["serde"]}
|
||||
cookie = "0.16"
|
||||
data-encoding = "2.3"
|
||||
delegate = "0.6"
|
||||
|
|
|
@ -3,18 +3,14 @@ use std::sync::Arc;
|
|||
use poem::session::Session;
|
||||
use poem::web::Data;
|
||||
use poem::Request;
|
||||
use poem_openapi::param::Path;
|
||||
use poem_openapi::payload::Json;
|
||||
use poem_openapi::{ApiResponse, Enum, Object, OpenApi};
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::*;
|
||||
use uuid::Uuid;
|
||||
use warpgate_common::auth::{AuthCredential, AuthState, CredentialKind};
|
||||
use warpgate_common::auth::{AuthCredential, CredentialKind};
|
||||
use warpgate_common::{AuthResult, Secret, Services};
|
||||
|
||||
use crate::common::{
|
||||
authorize_session, endpoint_auth, get_auth_state_for_request, SessionAuthorization, SessionExt,
|
||||
};
|
||||
use crate::common::{authorize_session, get_auth_state_for_request, SessionExt};
|
||||
use crate::session::SessionStore;
|
||||
|
||||
pub struct Api;
|
||||
|
@ -37,8 +33,6 @@ enum ApiAuthState {
|
|||
PasswordNeeded,
|
||||
OtpNeeded,
|
||||
SsoNeeded,
|
||||
WebUserApprovalNeeded,
|
||||
PublicKeyNeeded,
|
||||
Success,
|
||||
}
|
||||
|
||||
|
@ -64,7 +58,6 @@ enum LogoutResponse {
|
|||
|
||||
#[derive(Object)]
|
||||
struct AuthStateResponseInternal {
|
||||
pub protocol: String,
|
||||
pub state: ApiAuthState,
|
||||
}
|
||||
|
||||
|
@ -72,22 +65,17 @@ struct AuthStateResponseInternal {
|
|||
enum AuthStateResponse {
|
||||
#[oai(status = 200)]
|
||||
Ok(Json<AuthStateResponseInternal>),
|
||||
#[oai(status = 404)]
|
||||
NotFound,
|
||||
}
|
||||
|
||||
impl From<AuthResult> for ApiAuthState {
|
||||
fn from(state: AuthResult) -> Self {
|
||||
match state {
|
||||
AuthResult::Rejected => ApiAuthState::Failed,
|
||||
AuthResult::Need(kinds) => match kinds.iter().next() {
|
||||
Some(CredentialKind::Password) => ApiAuthState::PasswordNeeded,
|
||||
Some(CredentialKind::Otp) => ApiAuthState::OtpNeeded,
|
||||
Some(CredentialKind::Sso) => ApiAuthState::SsoNeeded,
|
||||
Some(CredentialKind::WebUserApproval) => ApiAuthState::WebUserApprovalNeeded,
|
||||
Some(CredentialKind::PublicKey) => ApiAuthState::PublicKeyNeeded,
|
||||
None => ApiAuthState::Failed,
|
||||
},
|
||||
AuthResult::Need(CredentialKind::Password) => ApiAuthState::PasswordNeeded,
|
||||
AuthResult::Need(CredentialKind::Otp) => ApiAuthState::OtpNeeded,
|
||||
AuthResult::Need(CredentialKind::Sso) => ApiAuthState::SsoNeeded,
|
||||
AuthResult::Need(CredentialKind::PublicKey) => ApiAuthState::Failed,
|
||||
AuthResult::NeedMoreCredentials => ApiAuthState::Failed,
|
||||
AuthResult::Accepted { .. } => ApiAuthState::Success,
|
||||
}
|
||||
}
|
||||
|
@ -104,9 +92,8 @@ impl Api {
|
|||
body: Json<LoginRequest>,
|
||||
) -> poem::Result<LoginResponse> {
|
||||
let mut auth_state_store = services.auth_state_store.lock().await;
|
||||
let state_arc =
|
||||
let state =
|
||||
get_auth_state_for_request(&body.username, session, &mut auth_state_store).await?;
|
||||
let mut state = state_arc.lock().await;
|
||||
|
||||
let mut cp = services.config_provider.lock().await;
|
||||
|
||||
|
@ -120,7 +107,6 @@ impl Api {
|
|||
|
||||
match state.verify() {
|
||||
AuthResult::Accepted { username } => {
|
||||
auth_state_store.complete(state.id()).await;
|
||||
authorize_session(req, username).await?;
|
||||
Ok(LoginResponse::Success)
|
||||
}
|
||||
|
@ -145,14 +131,12 @@ impl Api {
|
|||
|
||||
let mut auth_state_store = services.auth_state_store.lock().await;
|
||||
|
||||
let Some(state_arc) = state_id.and_then(|id| auth_state_store.get(&id.0)) else {
|
||||
let Some(state) = state_id.and_then(|id| auth_state_store.get_mut(&id.0)) else {
|
||||
return Ok(LoginResponse::Failure(Json(LoginFailureResponse {
|
||||
state: ApiAuthState::NotStarted,
|
||||
})))
|
||||
};
|
||||
|
||||
let mut state = state_arc.lock().await;
|
||||
|
||||
let mut cp = services.config_provider.lock().await;
|
||||
|
||||
let otp_cred = AuthCredential::Otp(body.otp.clone().into());
|
||||
|
@ -162,7 +146,6 @@ impl Api {
|
|||
|
||||
match state.verify() {
|
||||
AuthResult::Accepted { username } => {
|
||||
auth_state_store.complete(state.id()).await;
|
||||
authorize_session(req, username).await?;
|
||||
Ok(LoginResponse::Success)
|
||||
}
|
||||
|
@ -172,6 +155,27 @@ impl Api {
|
|||
}
|
||||
}
|
||||
|
||||
#[oai(path = "/auth/state", method = "get", operation_id = "getAuthState")]
|
||||
async fn api_auth_state(
|
||||
&self,
|
||||
session: &Session,
|
||||
services: Data<&Services>,
|
||||
) -> poem::Result<AuthStateResponse> {
|
||||
let state_id = session.get_auth_state_id();
|
||||
|
||||
let mut auth_state_store = services.auth_state_store.lock().await;
|
||||
|
||||
let Some(state) = state_id.and_then(|id| auth_state_store.get_mut(&id.0)) else {
|
||||
return Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal {
|
||||
state: ApiAuthState::NotStarted,
|
||||
})));
|
||||
};
|
||||
|
||||
Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal {
|
||||
state: state.verify().into(),
|
||||
})))
|
||||
}
|
||||
|
||||
#[oai(path = "/auth/logout", method = "post", operation_id = "logout")]
|
||||
async fn api_auth_logout(
|
||||
&self,
|
||||
|
@ -183,129 +187,4 @@ impl Api {
|
|||
info!("Logged out");
|
||||
Ok(LogoutResponse::Success)
|
||||
}
|
||||
|
||||
#[oai(
|
||||
path = "/auth/state",
|
||||
method = "get",
|
||||
operation_id = "getDefaultAuthState"
|
||||
)]
|
||||
async fn api_default_auth_state(
|
||||
&self,
|
||||
session: &Session,
|
||||
services: Data<&Services>,
|
||||
) -> poem::Result<AuthStateResponse> {
|
||||
let Some(state_id) = session.get_auth_state_id() else {
|
||||
return Ok(AuthStateResponse::NotFound)
|
||||
};
|
||||
let store = services.auth_state_store.lock().await;
|
||||
let Some(state_arc) = store.get(&state_id.0) else {
|
||||
return Ok(AuthStateResponse::NotFound);
|
||||
};
|
||||
serialize_auth_state_inner(state_arc).await
|
||||
}
|
||||
|
||||
#[oai(
|
||||
path = "/auth/state/:id",
|
||||
method = "get",
|
||||
operation_id = "get_auth_state",
|
||||
transform = "endpoint_auth"
|
||||
)]
|
||||
async fn api_auth_state(
|
||||
&self,
|
||||
services: Data<&Services>,
|
||||
auth: Option<Data<&SessionAuthorization>>,
|
||||
id: Path<Uuid>,
|
||||
) -> poem::Result<AuthStateResponse> {
|
||||
let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else {
|
||||
return Ok(AuthStateResponse::NotFound);
|
||||
};
|
||||
serialize_auth_state_inner(state_arc).await
|
||||
}
|
||||
|
||||
#[oai(
|
||||
path = "/auth/state/:id/approve",
|
||||
method = "post",
|
||||
operation_id = "approve_auth",
|
||||
transform = "endpoint_auth"
|
||||
)]
|
||||
async fn api_approve_auth(
|
||||
&self,
|
||||
services: Data<&Services>,
|
||||
auth: Option<Data<&SessionAuthorization>>,
|
||||
id: Path<Uuid>,
|
||||
) -> poem::Result<AuthStateResponse> {
|
||||
let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else {
|
||||
return Ok(AuthStateResponse::NotFound);
|
||||
};
|
||||
|
||||
let auth_result = {
|
||||
let mut state = state_arc.lock().await;
|
||||
state.add_valid_credential(AuthCredential::WebUserApproval);
|
||||
state.verify()
|
||||
};
|
||||
|
||||
if let AuthResult::Accepted { .. } = auth_result {
|
||||
services.auth_state_store.lock().await.complete(&*id).await;
|
||||
}
|
||||
serialize_auth_state_inner(state_arc).await
|
||||
}
|
||||
|
||||
#[oai(
|
||||
path = "/auth/state/:id/reject",
|
||||
method = "post",
|
||||
operation_id = "reject_auth",
|
||||
transform = "endpoint_auth"
|
||||
)]
|
||||
async fn api_reject_auth(
|
||||
&self,
|
||||
services: Data<&Services>,
|
||||
auth: Option<Data<&SessionAuthorization>>,
|
||||
id: Path<Uuid>,
|
||||
) -> poem::Result<AuthStateResponse> {
|
||||
let Some(state_arc) = get_auth_state(&*id, *services, auth.map(|x|x.0)).await else {
|
||||
return Ok(AuthStateResponse::NotFound);
|
||||
};
|
||||
state_arc.lock().await.reject();
|
||||
services.auth_state_store.lock().await.complete(&*id).await;
|
||||
serialize_auth_state_inner(state_arc).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_auth_state(
|
||||
id: &Uuid,
|
||||
services: &Services,
|
||||
auth: Option<&SessionAuthorization>,
|
||||
) -> Option<Arc<Mutex<AuthState>>> {
|
||||
let store = services.auth_state_store.lock().await;
|
||||
|
||||
let Some(auth) = auth else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let SessionAuthorization::User(username) = auth else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let Some(state_arc) = store.get(&*id) else {
|
||||
return None;
|
||||
};
|
||||
|
||||
{
|
||||
let state = state_arc.lock().await;
|
||||
if state.username() != username {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
Some(state_arc)
|
||||
}
|
||||
|
||||
async fn serialize_auth_state_inner(
|
||||
state_arc: Arc<Mutex<AuthState>>,
|
||||
) -> poem::Result<AuthStateResponse> {
|
||||
let state = state_arc.lock().await;
|
||||
Ok(AuthStateResponse::Ok(Json(AuthStateResponseInternal {
|
||||
protocol: state.protocol().to_string(),
|
||||
state: state.verify().into(),
|
||||
})))
|
||||
}
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
use poem::session::Session;
|
||||
use poem::web::Data;
|
||||
use poem::Request;
|
||||
use poem_openapi::param::{Path, Query};
|
||||
use poem_openapi::param::Path;
|
||||
use poem_openapi::payload::Json;
|
||||
use poem_openapi::{ApiResponse, Object, OpenApi};
|
||||
use reqwest::Url;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use warpgate_common::Services;
|
||||
use warpgate_sso::{SsoClient, SsoLoginRequest};
|
||||
|
@ -30,7 +31,6 @@ pub static SSO_CONTEXT_SESSION_KEY: &str = "sso_request";
|
|||
pub struct SsoContext {
|
||||
pub provider: String,
|
||||
pub request: SsoLoginRequest,
|
||||
pub next_url: Option<String>,
|
||||
}
|
||||
|
||||
#[OpenApi]
|
||||
|
@ -46,14 +46,31 @@ impl Api {
|
|||
session: &Session,
|
||||
services: Data<&Services>,
|
||||
name: Path<String>,
|
||||
next: Query<Option<String>>,
|
||||
) -> poem::Result<StartSsoResponse> {
|
||||
let config = services.config.lock().await;
|
||||
|
||||
let name = name.0;
|
||||
let ext_host = config
|
||||
.store
|
||||
.external_host
|
||||
.as_deref()
|
||||
.or_else(|| req.original_uri().host());
|
||||
let Some(ext_host) = ext_host else {
|
||||
return Err(poem::Error::from_string("external_host config option is required for SSO", http::status::StatusCode::INTERNAL_SERVER_ERROR));
|
||||
};
|
||||
let ext_port = config.store.http.listen.port();
|
||||
|
||||
let mut return_url = config.construct_external_url(req.original_uri().host())?;
|
||||
return_url.set_path("@warpgate/api/sso/return");
|
||||
let mut return_url = Url::parse(&format!("https://{ext_host}/@warpgate/api/sso/return"))
|
||||
.map_err(|e| {
|
||||
poem::Error::from_string(
|
||||
format!("failed to construct the return URL: {e}"),
|
||||
http::status::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
})?;
|
||||
|
||||
if ext_port != 443 {
|
||||
let _ = return_url.set_port(Some(ext_port));
|
||||
}
|
||||
|
||||
let Some(provider_config) = config.store.sso_providers.iter().find(|p| p.name == *name) else {
|
||||
return Ok(StartSsoResponse::NotFound);
|
||||
|
@ -67,14 +84,10 @@ impl Api {
|
|||
.map_err(poem::error::InternalServerError)?;
|
||||
|
||||
let url = sso_req.auth_url().to_string();
|
||||
session.set(
|
||||
SSO_CONTEXT_SESSION_KEY,
|
||||
SsoContext {
|
||||
provider: name,
|
||||
request: sso_req,
|
||||
next_url: next.0.clone(),
|
||||
},
|
||||
);
|
||||
session.set(SSO_CONTEXT_SESSION_KEY, SsoContext {
|
||||
provider: name,
|
||||
request: sso_req,
|
||||
});
|
||||
|
||||
Ok(StartSsoResponse::Ok(Json(StartSsoResponseParams { url })))
|
||||
}
|
||||
|
|
|
@ -120,9 +120,8 @@ impl Api {
|
|||
};
|
||||
|
||||
let mut auth_state_store = services.auth_state_store.lock().await;
|
||||
let state_arc = get_auth_state_for_request(&username, session, &mut auth_state_store).await?;
|
||||
let state = get_auth_state_for_request(&username, session, &mut auth_state_store).await?;
|
||||
|
||||
let mut state = state_arc.lock().await;
|
||||
let mut cp = services.config_provider.lock().await;
|
||||
|
||||
if cp.validate_credential(&username, &cred).await? {
|
||||
|
@ -131,15 +130,11 @@ impl Api {
|
|||
|
||||
match state.verify() {
|
||||
AuthResult::Accepted { username } => {
|
||||
auth_state_store.complete(state.id()).await;
|
||||
authorize_session(req, username).await?;
|
||||
}
|
||||
_ => (),
|
||||
_ => ()
|
||||
}
|
||||
|
||||
Ok(Response::new(ReturnToSsoResponse::Ok).header(
|
||||
"Location",
|
||||
context.next_url.as_deref().unwrap_or("/@warpgate#/login"),
|
||||
))
|
||||
Ok(Response::new(ReturnToSsoResponse::Ok).header("Location", "/@warpgate"))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -170,18 +170,18 @@ pub fn gateway_redirect(req: &Request) -> Response {
|
|||
.unwrap_or("".into());
|
||||
|
||||
let path = format!(
|
||||
"/@warpgate#/login?next={}",
|
||||
"/@warpgate?next={}",
|
||||
utf8_percent_encode(&path, NON_ALPHANUMERIC),
|
||||
);
|
||||
|
||||
Redirect::temporary(path).into_response()
|
||||
}
|
||||
|
||||
pub async fn get_auth_state_for_request(
|
||||
pub async fn get_auth_state_for_request<'a>(
|
||||
username: &str,
|
||||
session: &Session,
|
||||
store: &mut AuthStateStore,
|
||||
) -> Result<Arc<Mutex<AuthState>>, WarpgateError> {
|
||||
store: &'a mut AuthStateStore,
|
||||
) -> Result<&'a mut AuthState, WarpgateError> {
|
||||
match session.get_auth_state_id() {
|
||||
Some(id) => {
|
||||
if !store.contains_key(&id.0) {
|
||||
|
@ -192,7 +192,7 @@ pub async fn get_auth_state_for_request(
|
|||
};
|
||||
|
||||
match session.get_auth_state_id() {
|
||||
Some(id) => Ok(store.get(&id.0).unwrap()),
|
||||
Some(id) => Ok(store.get_mut(&id.0).unwrap()),
|
||||
None => {
|
||||
let (id, state) = store
|
||||
.create(&username, crate::common::PROTOCOL_NAME)
|
||||
|
|
|
@ -7,7 +7,7 @@ use tokio::net::TcpStream;
|
|||
use tokio::sync::Mutex;
|
||||
use tracing::*;
|
||||
use uuid::Uuid;
|
||||
use warpgate_common::auth::{AuthCredential, AuthSelector};
|
||||
use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState};
|
||||
use warpgate_common::helpers::rng::get_crypto_rng;
|
||||
use warpgate_common::{
|
||||
authorize_ticket, AuthResult, Secret, Services, TargetMySqlOptions, TargetOptions,
|
||||
|
@ -180,20 +180,15 @@ impl MySqlSession {
|
|||
username,
|
||||
target_name,
|
||||
} => {
|
||||
let state_arc = self
|
||||
.services
|
||||
.auth_state_store
|
||||
.lock()
|
||||
.await
|
||||
.create(&username, crate::common::PROTOCOL_NAME)
|
||||
.await?
|
||||
.1;
|
||||
let mut state = state_arc.lock().await;
|
||||
|
||||
let user_auth_result = {
|
||||
let credential = AuthCredential::Password(password);
|
||||
|
||||
let mut cp = self.services.config_provider.lock().await;
|
||||
|
||||
let credential = AuthCredential::Password(password);
|
||||
let mut state = AuthState::new(
|
||||
username.clone(),
|
||||
crate::common::PROTOCOL_NAME.to_string(),
|
||||
cp.get_credential_policy(&username).await?,
|
||||
);
|
||||
if cp.validate_credential(&username, &credential).await? {
|
||||
state.add_valid_credential(credential);
|
||||
}
|
||||
|
@ -203,12 +198,6 @@ impl MySqlSession {
|
|||
|
||||
match user_auth_result {
|
||||
AuthResult::Accepted { username } => {
|
||||
self.services
|
||||
.auth_state_store
|
||||
.lock()
|
||||
.await
|
||||
.complete(state.id())
|
||||
.await;
|
||||
let target_auth_result = {
|
||||
self.services
|
||||
.config_provider
|
||||
|
@ -227,7 +216,9 @@ impl MySqlSession {
|
|||
}
|
||||
self.run_authorized(handshake, username, target_name).await
|
||||
}
|
||||
AuthResult::Rejected | AuthResult::Need(_) => fail(&mut self).await, // TODO SSO
|
||||
AuthResult::Rejected
|
||||
| AuthResult::Need(_)
|
||||
| AuthResult::NeedMoreCredentials => fail(&mut self).await, // TODO SSO
|
||||
}
|
||||
}
|
||||
AuthSelector::Ticket { secret } => {
|
||||
|
|
|
@ -12,7 +12,7 @@ bimap = "0.6"
|
|||
bytes = "1.2"
|
||||
dialoguer = "0.10"
|
||||
futures = "0.3"
|
||||
russh = {version = "0.34.0-beta.8", features = ["openssl"]}
|
||||
russh = {version = "0.34.0-beta.7", features = ["openssl"]}
|
||||
russh-keys = {version = "0.22.0-beta.4", features = ["openssl"]}
|
||||
sea-orm = {version = "^0.9", features = ["runtime-tokio-native-tls"], default-features = false}
|
||||
thiserror = "1.0"
|
||||
|
|
|
@ -23,7 +23,6 @@ pub async fn run_server(services: Services, address: SocketAddr) -> Result<()> {
|
|||
let config = services.config.lock().await;
|
||||
russh::server::Config {
|
||||
auth_rejection_time: std::time::Duration::from_secs(1),
|
||||
connection_timeout: Some(std::time::Duration::from_secs(300)),
|
||||
methods: MethodSet::PUBLICKEY | MethodSet::PASSWORD | MethodSet::KEYBOARD_INTERACTIVE,
|
||||
keys: load_host_keys(&config)?,
|
||||
..Default::default()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::borrow::Cow;
|
||||
use std::collections::hash_map::Entry::Vacant;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::HashMap;
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
@ -10,11 +10,11 @@ use anyhow::{Context, Result};
|
|||
use bimap::BiMap;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use russh::server::Session;
|
||||
use russh::{CryptoVec, MethodSet, Sig};
|
||||
use russh::{CryptoVec, Sig};
|
||||
use russh_keys::key::PublicKey;
|
||||
use russh_keys::PublicKeyBase64;
|
||||
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
|
||||
use tokio::sync::{broadcast, oneshot, Mutex};
|
||||
use tokio::sync::{oneshot, Mutex};
|
||||
use tracing::*;
|
||||
use uuid::Uuid;
|
||||
use warpgate_common::auth::{AuthCredential, AuthSelector, AuthState, CredentialKind};
|
||||
|
@ -53,12 +53,6 @@ enum Event {
|
|||
Client(RCEvent),
|
||||
}
|
||||
|
||||
enum KeyboardInteractiveState {
|
||||
None,
|
||||
OtpRequested,
|
||||
WebAuthRequested(broadcast::Receiver<AuthResult>),
|
||||
}
|
||||
|
||||
pub struct ServerSession {
|
||||
pub id: SessionId,
|
||||
username: Option<String>,
|
||||
|
@ -80,8 +74,7 @@ pub struct ServerSession {
|
|||
hub: EventHub<Event>,
|
||||
event_sender: EventSender<Event>,
|
||||
service_output: ServiceOutput,
|
||||
auth_state: Option<Arc<Mutex<AuthState>>>,
|
||||
keyboard_interactive_state: KeyboardInteractiveState,
|
||||
auth_state: Option<AuthState>,
|
||||
}
|
||||
|
||||
fn session_debug_tag(id: &SessionId, remote_address: &SocketAddr) -> String {
|
||||
|
@ -149,7 +142,6 @@ impl ServerSession {
|
|||
so_tx.send(BytesMut::from(data).freeze()).context("x")
|
||||
})),
|
||||
auth_state: None,
|
||||
keyboard_interactive_state: KeyboardInteractiveState::None,
|
||||
};
|
||||
|
||||
let this = Arc::new(Mutex::new(this));
|
||||
|
@ -225,23 +217,18 @@ impl ServerSession {
|
|||
Ok(this)
|
||||
}
|
||||
|
||||
async fn get_auth_state(&mut self, username: &str) -> Result<Arc<Mutex<AuthState>>> {
|
||||
async fn get_auth_state(&mut self, username: &str) -> Result<&mut AuthState> {
|
||||
#[allow(clippy::unwrap_used)]
|
||||
if self.auth_state.is_none()
|
||||
|| self.auth_state.as_ref().unwrap().lock().await.username() != username
|
||||
{
|
||||
let state = self
|
||||
.services
|
||||
.auth_state_store
|
||||
.lock()
|
||||
.await
|
||||
.create(username, crate::PROTOCOL_NAME)
|
||||
.await?
|
||||
.1;
|
||||
self.auth_state = Some(state);
|
||||
if self.auth_state.is_none() || self.auth_state.as_ref().unwrap().username() != username {
|
||||
let mut cp = self.services.config_provider.lock().await;
|
||||
self.auth_state = Some(AuthState::new(
|
||||
username.to_string(),
|
||||
crate::PROTOCOL_NAME.to_string(),
|
||||
cp.get_credential_policy(username).await?,
|
||||
));
|
||||
}
|
||||
#[allow(clippy::unwrap_used)]
|
||||
Ok(self.auth_state.as_ref().map(Clone::clone).unwrap())
|
||||
Ok(self.auth_state.as_mut().unwrap())
|
||||
}
|
||||
|
||||
pub fn make_logging_span(&self) -> tracing::Span {
|
||||
|
@ -952,17 +939,13 @@ impl ServerSession {
|
|||
.await
|
||||
{
|
||||
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
||||
Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
|
||||
proceed_with_methods: Some(MethodSet::all()),
|
||||
},
|
||||
Ok(AuthResult::Need(kinds)) => russh::server::Auth::Reject {
|
||||
proceed_with_methods: Some(self.get_remaining_auth_methods(kinds)),
|
||||
},
|
||||
Ok(AuthResult::Rejected) => russh::server::Auth::Reject,
|
||||
Ok(AuthResult::Need(_) | AuthResult::NeedMoreCredentials) => {
|
||||
russh::server::Auth::Reject
|
||||
}
|
||||
Err(error) => {
|
||||
error!(?error, "Failed to verify credentials");
|
||||
russh::server::Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
}
|
||||
russh::server::Auth::Reject
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -980,17 +963,13 @@ impl ServerSession {
|
|||
.await
|
||||
{
|
||||
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
||||
Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
},
|
||||
Ok(AuthResult::Need(_)) => russh::server::Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
},
|
||||
Ok(AuthResult::Rejected) => russh::server::Auth::Reject,
|
||||
Ok(AuthResult::Need(_) | AuthResult::NeedMoreCredentials) => {
|
||||
russh::server::Auth::Reject
|
||||
}
|
||||
Err(error) => {
|
||||
error!(?error, "Failed to verify credentials");
|
||||
russh::server::Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
}
|
||||
russh::server::Auth::Reject
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1003,111 +982,27 @@ impl ServerSession {
|
|||
let selector: AuthSelector = ssh_username.expose_secret().into();
|
||||
info!("Keyboard-interactive auth as {:?}", selector);
|
||||
|
||||
let cred;
|
||||
match &mut self.keyboard_interactive_state {
|
||||
KeyboardInteractiveState::None => {
|
||||
cred = None;
|
||||
}
|
||||
KeyboardInteractiveState::OtpRequested => {
|
||||
cred = response.map(AuthCredential::Otp);
|
||||
}
|
||||
KeyboardInteractiveState::WebAuthRequested(event) => {
|
||||
cred = None;
|
||||
let _ = event.recv().await;
|
||||
// the auth state has been updated by now
|
||||
}
|
||||
}
|
||||
|
||||
self.keyboard_interactive_state = KeyboardInteractiveState::None;
|
||||
let cred = response.map(AuthCredential::Otp);
|
||||
|
||||
match self.try_auth(&selector, cred).await {
|
||||
Ok(AuthResult::Accepted { .. }) => russh::server::Auth::Accept,
|
||||
Ok(AuthResult::Rejected) => russh::server::Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
Ok(
|
||||
AuthResult::Rejected
|
||||
| AuthResult::NeedMoreCredentials
|
||||
| AuthResult::Need(CredentialKind::Otp),
|
||||
) => russh::server::Auth::Partial {
|
||||
name: Cow::Borrowed("Two-factor authentication"),
|
||||
instructions: Cow::Borrowed(""),
|
||||
prompts: Cow::Owned(vec![(Cow::Borrowed("One-time password: "), true)]),
|
||||
},
|
||||
Ok(AuthResult::Need(kinds)) => {
|
||||
if kinds.contains(&CredentialKind::Otp) {
|
||||
self.keyboard_interactive_state = KeyboardInteractiveState::OtpRequested;
|
||||
russh::server::Auth::Partial {
|
||||
name: Cow::Borrowed("Two-factor authentication"),
|
||||
instructions: Cow::Borrowed(""),
|
||||
prompts: Cow::Owned(vec![(Cow::Borrowed("One-time password: "), true)]),
|
||||
}
|
||||
} else if kinds.contains(&CredentialKind::WebUserApproval) {
|
||||
let Some(auth_state) = self.auth_state.as_ref() else {
|
||||
return russh::server::Auth::Reject { proceed_with_methods: None};
|
||||
};
|
||||
let auth_state_id = *auth_state.lock().await.id();
|
||||
let event = self
|
||||
.services
|
||||
.auth_state_store
|
||||
.lock()
|
||||
.await
|
||||
.subscribe(auth_state_id);
|
||||
self.keyboard_interactive_state =
|
||||
KeyboardInteractiveState::WebAuthRequested(event);
|
||||
|
||||
let mut login_url = match self
|
||||
.services
|
||||
.config
|
||||
.lock()
|
||||
.await
|
||||
.construct_external_url(None)
|
||||
{
|
||||
Ok(url) => url,
|
||||
Err(error) => {
|
||||
error!(?error, "Failed to construct external URL");
|
||||
return russh::server::Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
login_url.set_path("@warpgate");
|
||||
login_url
|
||||
.set_fragment(Some(&format!("/login?next=%2Flogin%2F{auth_state_id}")));
|
||||
|
||||
russh::server::Auth::Partial {
|
||||
name: Cow::Owned(format!(
|
||||
concat!(
|
||||
"----------------------------------------------------------------\n",
|
||||
"Warpgate authentication: please open {} in your browser\n",
|
||||
"----------------------------------------------------------------\n"
|
||||
),
|
||||
login_url
|
||||
)),
|
||||
instructions: Cow::Borrowed(""),
|
||||
prompts: Cow::Owned(vec![(Cow::Borrowed("Press Enter when done: "), true)]),
|
||||
}
|
||||
} else {
|
||||
russh::server::Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(AuthResult::Need(_)) => russh::server::Auth::Reject, // TODO SSO
|
||||
Err(error) => {
|
||||
error!(?error, "Failed to verify credentials");
|
||||
russh::server::Auth::Reject {
|
||||
proceed_with_methods: None,
|
||||
}
|
||||
russh::server::Auth::Reject
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_remaining_auth_methods(&self, kinds: HashSet<CredentialKind>) -> MethodSet {
|
||||
let mut m = MethodSet::empty();
|
||||
for kind in kinds {
|
||||
match kind {
|
||||
CredentialKind::Password => m.insert(MethodSet::PASSWORD),
|
||||
CredentialKind::Otp => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
|
||||
CredentialKind::WebUserApproval => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
|
||||
CredentialKind::PublicKey => m.insert(MethodSet::PUBLICKEY),
|
||||
CredentialKind::Sso => m.insert(MethodSet::KEYBOARD_INTERACTIVE),
|
||||
}
|
||||
}
|
||||
m
|
||||
}
|
||||
|
||||
async fn try_auth(
|
||||
&mut self,
|
||||
selector: &AuthSelector,
|
||||
|
@ -1119,9 +1014,7 @@ impl ServerSession {
|
|||
target_name,
|
||||
} => {
|
||||
let cp = self.services.config_provider.clone();
|
||||
|
||||
let state_arc = self.get_auth_state(username).await?;
|
||||
let mut state = state_arc.lock().await;
|
||||
let state = self.get_auth_state(username).await?;
|
||||
|
||||
if let Some(credential) = credential {
|
||||
if cp
|
||||
|
@ -1138,12 +1031,6 @@ impl ServerSession {
|
|||
|
||||
match user_auth_result {
|
||||
AuthResult::Accepted { username } => {
|
||||
self.services
|
||||
.auth_state_store
|
||||
.lock()
|
||||
.await
|
||||
.complete(state.id())
|
||||
.await;
|
||||
let target_auth_result = {
|
||||
self.services
|
||||
.config_provider
|
||||
|
|
|
@ -2,25 +2,22 @@
|
|||
import { faSignOut } from '@fortawesome/free-solid-svg-icons'
|
||||
import { Alert, Spinner } from 'sveltestrap'
|
||||
import Fa from 'svelte-fa'
|
||||
import Router, { push } from 'svelte-spa-router'
|
||||
import { wrap } from 'svelte-spa-router/wrap'
|
||||
import { get } from 'svelte/store'
|
||||
import { api } from 'gateway/lib/api'
|
||||
import { reloadServerInfo, serverInfo } from 'gateway/lib/store'
|
||||
import ThemeSwitcher from 'common/ThemeSwitcher.svelte'
|
||||
import Login from './Login.svelte'
|
||||
import TargetList from './TargetList.svelte'
|
||||
import Logo from 'common/Logo.svelte'
|
||||
|
||||
let redirecting = false
|
||||
let serverInfoPromise = reloadServerInfo()
|
||||
|
||||
async function init () {
|
||||
await serverInfoPromise
|
||||
await reloadServerInfo()
|
||||
}
|
||||
|
||||
async function logout () {
|
||||
await api.logout()
|
||||
await reloadServerInfo()
|
||||
push('/login')
|
||||
}
|
||||
|
||||
function onPageResume () {
|
||||
|
@ -28,36 +25,6 @@ function onPageResume () {
|
|||
init()
|
||||
}
|
||||
|
||||
async function requireLogin (detail) {
|
||||
await serverInfoPromise
|
||||
if (!get(serverInfo)?.username) {
|
||||
let url = detail.location
|
||||
if (detail.querystring) {
|
||||
url += '?' + detail.querystring
|
||||
}
|
||||
push('/login?next=' + encodeURIComponent(url))
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
const routes = {
|
||||
'/': wrap({
|
||||
asyncComponent: () => import('./TargetList.svelte'),
|
||||
props: {
|
||||
'on:navigation': () => redirecting = true,
|
||||
},
|
||||
conditions: [requireLogin],
|
||||
}),
|
||||
'/login': wrap({
|
||||
asyncComponent: () => import('./Login.svelte'),
|
||||
}),
|
||||
'/login/:stateId': wrap({
|
||||
asyncComponent: () => import('./OutOfBandAuth.svelte'),
|
||||
conditions: [requireLogin],
|
||||
}),
|
||||
}
|
||||
|
||||
init()
|
||||
</script>
|
||||
|
||||
|
@ -71,9 +38,9 @@ init()
|
|||
<Spinner />
|
||||
{:else}
|
||||
<div class="d-flex align-items-center mt-5 mb-5">
|
||||
<a class="logo" href="/@warpgate">
|
||||
<div class="logo">
|
||||
<Logo />
|
||||
</a>
|
||||
</div>
|
||||
|
||||
{#if $serverInfo?.username}
|
||||
<div class="ms-auto">
|
||||
|
@ -89,7 +56,12 @@ init()
|
|||
</div>
|
||||
|
||||
<main>
|
||||
<Router {routes}/>
|
||||
{#if $serverInfo?.username}
|
||||
<TargetList
|
||||
on:navigation={() => redirecting = true} />
|
||||
{:else}
|
||||
<Login />
|
||||
{/if}
|
||||
</main>
|
||||
|
||||
<footer class="mt-5">
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
<script lang="ts">
|
||||
import { get } from 'svelte/store'
|
||||
import { querystring, replace } from 'svelte-spa-router'
|
||||
import { replace } from 'svelte-spa-router'
|
||||
import { Alert, FormGroup, Spinner } from 'sveltestrap'
|
||||
import Fa from 'svelte-fa'
|
||||
import { faArrowRight } from '@fortawesome/free-solid-svg-icons'
|
||||
|
@ -10,47 +9,25 @@ import { api, ApiAuthState, LoginFailureResponseFromJSON, SsoProviderDescription
|
|||
import { reloadServerInfo } from 'gateway/lib/store'
|
||||
import AsyncButton from 'common/AsyncButton.svelte'
|
||||
|
||||
export let params: { stateId?: string } = {}
|
||||
|
||||
let error: Error|null = null
|
||||
let username = ''
|
||||
let password = ''
|
||||
let otp = ''
|
||||
let busy = false
|
||||
|
||||
let authState: ApiAuthState|undefined = undefined
|
||||
let authState = ApiAuthState.NotStarted
|
||||
|
||||
let ssoProvidersPromise = api.getSsoProviders()
|
||||
|
||||
const nextURL = new URLSearchParams(get(querystring)).get('next') ?? undefined
|
||||
const nextURL = new URLSearchParams(location.search).get('next')
|
||||
const serverErrorMessage = new URLSearchParams(location.search).get('login_error')
|
||||
|
||||
async function init () {
|
||||
try {
|
||||
authState = (await api.getDefaultAuthState()).state
|
||||
} catch (err) {
|
||||
if (err.status) {
|
||||
const response = err as Response
|
||||
if (response.status === 404) {
|
||||
authState = ApiAuthState.NotStarted
|
||||
}
|
||||
}
|
||||
}
|
||||
authState = (await api.getAuthState()).state
|
||||
continueWithState()
|
||||
}
|
||||
|
||||
function success () {
|
||||
if (nextURL) {
|
||||
replace(nextURL)
|
||||
} else {
|
||||
replace('/')
|
||||
}
|
||||
}
|
||||
|
||||
async function continueWithState () {
|
||||
if (authState === ApiAuthState.Success) {
|
||||
success()
|
||||
}
|
||||
if (authState === ApiAuthState.SsoNeeded) {
|
||||
const providers = await ssoProvidersPromise
|
||||
if (!providers.length) {
|
||||
|
@ -88,15 +65,18 @@ async function _login () {
|
|||
},
|
||||
})
|
||||
}
|
||||
await reloadServerInfo()
|
||||
success()
|
||||
if (nextURL) {
|
||||
location.href = nextURL
|
||||
} else {
|
||||
await reloadServerInfo()
|
||||
replace('/')
|
||||
}
|
||||
} catch (err) {
|
||||
if (err.status) {
|
||||
const response = err as Response
|
||||
if (response.status === 401) {
|
||||
const failure = LoginFailureResponseFromJSON(await response.json())
|
||||
authState = failure.state
|
||||
|
||||
continueWithState()
|
||||
} else {
|
||||
error = new Error(await response.text())
|
||||
|
@ -116,8 +96,8 @@ function onInputKey (event: KeyboardEvent) {
|
|||
async function startSSO (provider: SsoProviderDescription) {
|
||||
busy = true
|
||||
try {
|
||||
const p = await api.startSso({ name: provider.name, next: nextURL })
|
||||
location.href = p.url
|
||||
const params = await api.startSso(provider)
|
||||
location.href = params.url
|
||||
} catch {
|
||||
busy = false
|
||||
}
|
||||
|
@ -135,10 +115,6 @@ async function startSSO (provider: SsoProviderDescription) {
|
|||
<h1>Continue login</h1>
|
||||
{/if}
|
||||
</div>
|
||||
{#if params.stateId}
|
||||
//todo
|
||||
loggin in for auth state id {params.stateId}
|
||||
{/if}
|
||||
{#if authState === ApiAuthState.OtpNeeded}
|
||||
<FormGroup floating label="One-time password">
|
||||
<!-- svelte-ignore a11y-autofocus -->
|
||||
|
|
|
@ -1,65 +0,0 @@
|
|||
<script lang="ts">
|
||||
import { Alert, Spinner } from 'sveltestrap'
|
||||
|
||||
import { api, ApiAuthState, AuthStateResponseInternal } from 'gateway/lib/api'
|
||||
import AsyncButton from 'common/AsyncButton.svelte'
|
||||
|
||||
export let params: { stateId: string }
|
||||
let authState: AuthStateResponseInternal
|
||||
|
||||
async function reload () {
|
||||
authState = await api.getAuthState({ id: params.stateId })
|
||||
}
|
||||
|
||||
async function init () {
|
||||
await reload()
|
||||
}
|
||||
|
||||
async function approve () {
|
||||
api.approveAuth({ id: params.stateId })
|
||||
await reload()
|
||||
}
|
||||
|
||||
async function reject () {
|
||||
api.rejectAuth({ id: params.stateId })
|
||||
await reload()
|
||||
}
|
||||
</script>
|
||||
|
||||
{#await init()}
|
||||
<Spinner />
|
||||
{:then}
|
||||
<div class="page-summary-bar">
|
||||
<h1>Authorization request</h1>
|
||||
</div>
|
||||
|
||||
<p>Authorize this {authState.protocol} session?</p>
|
||||
|
||||
{#if authState.state === ApiAuthState.Success}
|
||||
<Alert color="success">
|
||||
Approved
|
||||
</Alert>
|
||||
{:else if authState.state === ApiAuthState.Failed}
|
||||
<Alert color="danger">
|
||||
Rejected
|
||||
</Alert>
|
||||
{:else}
|
||||
<div class="d-flex">
|
||||
<AsyncButton
|
||||
color="primary"
|
||||
class="d-flex align-items-center ms-auto"
|
||||
click={approve}
|
||||
>
|
||||
Authorize
|
||||
</AsyncButton>
|
||||
<AsyncButton
|
||||
outline
|
||||
color="secondary"
|
||||
class="d-flex align-items-center ms-2"
|
||||
click={reject}
|
||||
>
|
||||
Reject
|
||||
</AsyncButton>
|
||||
</div>
|
||||
{/if}
|
||||
{/await}
|
|
@ -71,16 +71,6 @@
|
|||
"operationId": "otpLogin"
|
||||
}
|
||||
},
|
||||
"/auth/logout": {
|
||||
"post": {
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": ""
|
||||
}
|
||||
},
|
||||
"operationId": "logout"
|
||||
}
|
||||
},
|
||||
"/auth/state": {
|
||||
"get": {
|
||||
"responses": {
|
||||
|
@ -93,108 +83,19 @@
|
|||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": ""
|
||||
}
|
||||
},
|
||||
"operationId": "getDefaultAuthState"
|
||||
"operationId": "getAuthState"
|
||||
}
|
||||
},
|
||||
"/auth/state/{id}": {
|
||||
"get": {
|
||||
"parameters": [
|
||||
{
|
||||
"name": "id",
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"deprecated": false
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/AuthStateResponseInternal"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": ""
|
||||
}
|
||||
},
|
||||
"operationId": "get_auth_state"
|
||||
}
|
||||
},
|
||||
"/auth/state/{id}/approve": {
|
||||
"/auth/logout": {
|
||||
"post": {
|
||||
"parameters": [
|
||||
{
|
||||
"name": "id",
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"deprecated": false
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/AuthStateResponseInternal"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"201": {
|
||||
"description": ""
|
||||
}
|
||||
},
|
||||
"operationId": "approve_auth"
|
||||
}
|
||||
},
|
||||
"/auth/state/{id}/reject": {
|
||||
"post": {
|
||||
"parameters": [
|
||||
{
|
||||
"name": "id",
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"deprecated": false
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/AuthStateResponseInternal"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": ""
|
||||
}
|
||||
},
|
||||
"operationId": "reject_auth"
|
||||
"operationId": "logout"
|
||||
}
|
||||
},
|
||||
"/info": {
|
||||
|
@ -286,15 +187,6 @@
|
|||
"in": "path",
|
||||
"required": true,
|
||||
"deprecated": false
|
||||
},
|
||||
{
|
||||
"name": "next",
|
||||
"schema": {
|
||||
"type": "string"
|
||||
},
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"deprecated": false
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
|
@ -326,21 +218,15 @@
|
|||
"PasswordNeeded",
|
||||
"OtpNeeded",
|
||||
"SsoNeeded",
|
||||
"WebUserApprovalNeeded",
|
||||
"PublicKeyNeeded",
|
||||
"Success"
|
||||
]
|
||||
},
|
||||
"AuthStateResponseInternal": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"protocol",
|
||||
"state"
|
||||
],
|
||||
"properties": {
|
||||
"protocol": {
|
||||
"type": "string"
|
||||
},
|
||||
"state": {
|
||||
"$ref": "#/components/schemas/ApiAuthState"
|
||||
}
|
||||
|
|
|
@ -30,9 +30,6 @@ pub(crate) async fn command(cli: &crate::Cli) -> Result<()> {
|
|||
.await
|
||||
.with_context(|| "Checking MySQL key".to_string())?;
|
||||
}
|
||||
if !config.store.sso_providers.is_empty() && config.store.external_host.is_none() {
|
||||
anyhow::bail!("SSO requires the external_host config option");
|
||||
}
|
||||
info!("No problems found");
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -77,6 +77,7 @@ fn check_and_migrate_config(store: &mut serde_yaml::Value) {
|
|||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn watch_config<P: AsRef<Path> + Send + 'static>(
|
||||
path: P,
|
||||
config: Arc<Mutex<WarpgateConfig>>,
|
||||
|
|
Loading…
Reference in a new issue