mirror of
https://github.com/warp-tech/warpgate.git
synced 2025-09-12 01:24:24 +08:00
fixed #406 - Apple ID SSO
This commit is contained in:
parent
5a7c39c4cb
commit
fffd799a5a
6 changed files with 144 additions and 43 deletions
54
Cargo.lock
generated
54
Cargo.lock
generated
|
@ -1001,21 +1001,6 @@ dependencies = [
|
||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "dhat"
|
|
||||||
version = "0.3.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "0684eaa19a59be283a6f99369917b679bd4d1d06604b2eb2e2f87b4bbd67668d"
|
|
||||||
dependencies = [
|
|
||||||
"backtrace",
|
|
||||||
"lazy_static",
|
|
||||||
"parking_lot 0.12.0",
|
|
||||||
"rustc-hash",
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
"thousands",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dialoguer"
|
name = "dialoguer"
|
||||||
version = "0.10.0"
|
version = "0.10.0"
|
||||||
|
@ -1830,6 +1815,20 @@ dependencies = [
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "jsonwebtoken"
|
||||||
|
version = "8.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1aa4b4af834c6cfd35d8763d359661b90f2e45d8f750a0849156c7f4671af09c"
|
||||||
|
dependencies = [
|
||||||
|
"base64",
|
||||||
|
"pem",
|
||||||
|
"ring",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"simple_asn1",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "kqueue"
|
name = "kqueue"
|
||||||
version = "1.0.5"
|
version = "1.0.5"
|
||||||
|
@ -2315,9 +2314,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "oauth2"
|
name = "oauth2"
|
||||||
version = "4.2.3"
|
version = "4.3.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6d62c436394991641b970a92e23e8eeb4eb9bca74af4f5badc53bcd568daadbd"
|
checksum = "eeaf26a72311c087f8c5ba617c96fac67a5c04f430e716ac8d8ab2de62e23368"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"base64",
|
"base64",
|
||||||
"chrono",
|
"chrono",
|
||||||
|
@ -3784,6 +3783,18 @@ version = "1.6.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f0ea32af43239f0d353a7dd75a22d94c329c8cdaafdcb4c1c1335aa10c298a4a"
|
checksum = "f0ea32af43239f0d353a7dd75a22d94c329c8cdaafdcb4c1c1335aa10c298a4a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "simple_asn1"
|
||||||
|
version = "0.6.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085"
|
||||||
|
dependencies = [
|
||||||
|
"num-bigint",
|
||||||
|
"num-traits",
|
||||||
|
"thiserror",
|
||||||
|
"time 0.3.15",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "siphasher"
|
name = "siphasher"
|
||||||
version = "0.3.10"
|
version = "0.3.10"
|
||||||
|
@ -4087,12 +4098,6 @@ dependencies = [
|
||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "thousands"
|
|
||||||
version = "0.2.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "3bf63baf9f5039dadc247375c29eb13706706cfde997d0330d05aa63a77d8820"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thread_local"
|
name = "thread_local"
|
||||||
version = "1.1.4"
|
version = "1.1.4"
|
||||||
|
@ -4654,7 +4659,6 @@ dependencies = [
|
||||||
"console",
|
"console",
|
||||||
"console-subscriber",
|
"console-subscriber",
|
||||||
"data-encoding",
|
"data-encoding",
|
||||||
"dhat",
|
|
||||||
"dialoguer",
|
"dialoguer",
|
||||||
"futures",
|
"futures",
|
||||||
"notify",
|
"notify",
|
||||||
|
@ -4909,6 +4913,8 @@ name = "warpgate-sso"
|
||||||
version = "0.6.5"
|
version = "0.6.5"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bytes",
|
"bytes",
|
||||||
|
"data-encoding",
|
||||||
|
"jsonwebtoken",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"openidconnect",
|
"openidconnect",
|
||||||
"serde",
|
"serde",
|
||||||
|
|
|
@ -9,7 +9,9 @@ bytes = "1.2"
|
||||||
thiserror = "1.0"
|
thiserror = "1.0"
|
||||||
tokio = { version = "1.20", features = ["tracing", "macros"] }
|
tokio = { version = "1.20", features = ["tracing", "macros"] }
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
openidconnect = { version = "2.4", features = ["reqwest", "rustls-tls"] }
|
openidconnect = { version = "2.4", features = ["reqwest", "rustls-tls", "accept-string-booleans"] }
|
||||||
serde = "1.0"
|
serde = "1.0"
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
once_cell = "1.14"
|
once_cell = "1.14"
|
||||||
|
jsonwebtoken = "8"
|
||||||
|
data-encoding = "2.3"
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
|
use data_encoding::BASE64;
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use openidconnect::{ClientId, ClientSecret, IssuerUrl};
|
use openidconnect::{AuthType, ClientId, ClientSecret, IssuerUrl};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::SsoError;
|
use crate::SsoError;
|
||||||
|
@ -42,6 +44,8 @@ pub enum SsoInternalProviderConfig {
|
||||||
Apple {
|
Apple {
|
||||||
client_id: ClientId,
|
client_id: ClientId,
|
||||||
client_secret: ClientSecret,
|
client_secret: ClientSecret,
|
||||||
|
key_id: String,
|
||||||
|
team_id: String,
|
||||||
},
|
},
|
||||||
#[serde(rename = "azure")]
|
#[serde(rename = "azure")]
|
||||||
Azure {
|
Azure {
|
||||||
|
@ -58,6 +62,15 @@ pub enum SsoInternalProviderConfig {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct AppleIDClaims<'a> {
|
||||||
|
sub: &'a str,
|
||||||
|
aud: &'a str,
|
||||||
|
exp: usize,
|
||||||
|
nbf: usize,
|
||||||
|
iss: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
impl SsoInternalProviderConfig {
|
impl SsoInternalProviderConfig {
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn label(&self) -> &'static str {
|
pub fn label(&self) -> &'static str {
|
||||||
|
@ -80,13 +93,53 @@ impl SsoInternalProviderConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn client_secret(&self) -> &ClientSecret {
|
pub fn client_secret(&self) -> Result<ClientSecret, SsoError> {
|
||||||
match self {
|
Ok(match self {
|
||||||
SsoInternalProviderConfig::Google { client_secret, .. }
|
SsoInternalProviderConfig::Google { client_secret, .. }
|
||||||
| SsoInternalProviderConfig::Apple { client_secret, .. }
|
|
||||||
| SsoInternalProviderConfig::Azure { client_secret, .. }
|
| SsoInternalProviderConfig::Azure { client_secret, .. }
|
||||||
| SsoInternalProviderConfig::Custom { client_secret, .. } => client_secret,
|
| SsoInternalProviderConfig::Custom { client_secret, .. } => client_secret.clone(),
|
||||||
}
|
SsoInternalProviderConfig::Apple {
|
||||||
|
client_secret,
|
||||||
|
client_id,
|
||||||
|
key_id,
|
||||||
|
team_id,
|
||||||
|
} => {
|
||||||
|
let key_content =
|
||||||
|
BASE64
|
||||||
|
.decode(client_secret.secret().as_bytes())
|
||||||
|
.map_err(|e| {
|
||||||
|
SsoError::ConfigError(format!(
|
||||||
|
"could not decode base64 client_secret: {e}"
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
let key = jsonwebtoken::EncodingKey::from_ec_pem(&key_content).map_err(|e| {
|
||||||
|
SsoError::ConfigError(format!(
|
||||||
|
"could not parse client_secret as a private key: {e}"
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
let mut header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::ES256);
|
||||||
|
header.kid = Some(key_id.into());
|
||||||
|
|
||||||
|
ClientSecret::new(jsonwebtoken::encode(
|
||||||
|
&header,
|
||||||
|
&AppleIDClaims {
|
||||||
|
aud: &APPLE_ISSUER_URL,
|
||||||
|
sub: client_id,
|
||||||
|
exp: SystemTime::now()
|
||||||
|
.duration_since(SystemTime::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs() as usize
|
||||||
|
+ 600,
|
||||||
|
nbf: SystemTime::now()
|
||||||
|
.duration_since(SystemTime::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs() as usize,
|
||||||
|
iss: team_id,
|
||||||
|
},
|
||||||
|
&key,
|
||||||
|
)?)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
|
@ -104,10 +157,11 @@ impl SsoInternalProviderConfig {
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn scopes(&self) -> Vec<String> {
|
pub fn scopes(&self) -> Vec<String> {
|
||||||
match self {
|
match self {
|
||||||
SsoInternalProviderConfig::Google { .. }
|
SsoInternalProviderConfig::Google { .. } | SsoInternalProviderConfig::Azure { .. } => {
|
||||||
| SsoInternalProviderConfig::Apple { .. }
|
vec!["email".to_string()]
|
||||||
| SsoInternalProviderConfig::Azure { .. } => vec!["email".to_string()],
|
}
|
||||||
SsoInternalProviderConfig::Custom { scopes, .. } => scopes.clone(),
|
SsoInternalProviderConfig::Custom { scopes, .. } => scopes.clone(),
|
||||||
|
SsoInternalProviderConfig::Apple { .. } => vec![],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -124,4 +178,24 @@ impl SsoInternalProviderConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn auth_type(&self) -> AuthType {
|
||||||
|
match self {
|
||||||
|
SsoInternalProviderConfig::Google { .. }
|
||||||
|
| SsoInternalProviderConfig::Custom { .. }
|
||||||
|
| SsoInternalProviderConfig::Azure { .. } => AuthType::BasicAuth,
|
||||||
|
SsoInternalProviderConfig::Apple { .. } => AuthType::RequestBody,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn needs_pkce_verifier(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
SsoInternalProviderConfig::Google { .. }
|
||||||
|
| SsoInternalProviderConfig::Custom { .. }
|
||||||
|
| SsoInternalProviderConfig::Azure { .. } => true,
|
||||||
|
SsoInternalProviderConfig::Apple { .. } => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,8 @@ pub enum SsoError {
|
||||||
Mitm,
|
Mitm,
|
||||||
#[error("config parse error: {0}")]
|
#[error("config parse error: {0}")]
|
||||||
UrlParse(#[from] openidconnect::url::ParseError),
|
UrlParse(#[from] openidconnect::url::ParseError),
|
||||||
|
#[error("config error: {0}")]
|
||||||
|
ConfigError(String),
|
||||||
#[error("provider discovery error: {0}")]
|
#[error("provider discovery error: {0}")]
|
||||||
Discovery(String),
|
Discovery(String),
|
||||||
#[error("code verification error: {0}")]
|
#[error("code verification error: {0}")]
|
||||||
|
@ -20,6 +22,8 @@ pub enum SsoError {
|
||||||
Signing(#[from] SigningError),
|
Signing(#[from] SigningError),
|
||||||
#[error("I/O: {0}")]
|
#[error("I/O: {0}")]
|
||||||
Io(#[from] std::io::Error),
|
Io(#[from] std::io::Error),
|
||||||
|
#[error("JWT error: {0}")]
|
||||||
|
Jwt(#[from] jsonwebtoken::errors::Error),
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Other(Box<dyn Error + Send + Sync>),
|
Other(Box<dyn Error + Send + Sync>),
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ pub struct SsoLoginRequest {
|
||||||
pub(crate) csrf_token: CsrfToken,
|
pub(crate) csrf_token: CsrfToken,
|
||||||
pub(crate) nonce: Nonce,
|
pub(crate) nonce: Nonce,
|
||||||
pub(crate) redirect_url: RedirectUrl,
|
pub(crate) redirect_url: RedirectUrl,
|
||||||
pub(crate) pkce_verifier: PkceCodeVerifier,
|
pub(crate) pkce_verifier: Option<PkceCodeVerifier>,
|
||||||
pub(crate) config: SsoInternalProviderConfig,
|
pub(crate) config: SsoInternalProviderConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -32,15 +32,23 @@ impl SsoLoginRequest {
|
||||||
.await?
|
.await?
|
||||||
.set_redirect_uri(self.redirect_url.clone());
|
.set_redirect_uri(self.redirect_url.clone());
|
||||||
|
|
||||||
let token_response = client
|
let mut req = client.exchange_code(AuthorizationCode::new(code));
|
||||||
.exchange_code(AuthorizationCode::new(code))
|
if let Some(verifier) = self.pkce_verifier {
|
||||||
.set_pkce_verifier(self.pkce_verifier)
|
req = req.set_pkce_verifier(verifier);
|
||||||
|
}
|
||||||
|
|
||||||
|
let token_response = req
|
||||||
.request_async(async_http_client)
|
.request_async(async_http_client)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| match e {
|
.map_err(|e| match e {
|
||||||
RequestTokenError::ServerResponse(response) => {
|
RequestTokenError::ServerResponse(response) => {
|
||||||
SsoError::Verification(response.error().to_string())
|
SsoError::Verification(response.error().to_string())
|
||||||
}
|
}
|
||||||
|
RequestTokenError::Parse(err, path) => SsoError::Verification(format!(
|
||||||
|
"Parse error: {:?} / {:?}",
|
||||||
|
err,
|
||||||
|
String::from_utf8_lossy(&path)
|
||||||
|
)),
|
||||||
e => SsoError::Verification(format!("{e}")),
|
e => SsoError::Verification(format!("{e}")),
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
|
|
@ -24,8 +24,9 @@ pub async fn make_client(config: &SsoInternalProviderConfig) -> Result<CoreClien
|
||||||
Ok(CoreClient::from_provider_metadata(
|
Ok(CoreClient::from_provider_metadata(
|
||||||
metadata,
|
metadata,
|
||||||
config.client_id().clone(),
|
config.client_id().clone(),
|
||||||
Some(config.client_secret().clone()),
|
Some(config.client_secret()?),
|
||||||
))
|
)
|
||||||
|
.set_auth_type(config.auth_type()))
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SsoClient {
|
impl SsoClient {
|
||||||
|
@ -34,8 +35,6 @@ impl SsoClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn start_login(&self, redirect_url: String) -> Result<SsoLoginRequest, SsoError> {
|
pub async fn start_login(&self, redirect_url: String) -> Result<SsoLoginRequest, SsoError> {
|
||||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
|
||||||
|
|
||||||
let redirect_url = RedirectUrl::new(redirect_url)?;
|
let redirect_url = RedirectUrl::new(redirect_url)?;
|
||||||
let client = make_client(&self.config).await?;
|
let client = make_client(&self.config).await?;
|
||||||
let mut auth_req = client
|
let mut auth_req = client
|
||||||
|
@ -54,7 +53,15 @@ impl SsoClient {
|
||||||
auth_req = auth_req.add_scope(Scope::new(scope.to_string()));
|
auth_req = auth_req.add_scope(Scope::new(scope.to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let (auth_url, csrf_token, nonce) = auth_req.set_pkce_challenge(pkce_challenge).url();
|
let pkce_verifier = if self.config.needs_pkce_verifier() {
|
||||||
|
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||||
|
auth_req = auth_req.set_pkce_challenge(pkce_challenge);
|
||||||
|
Some(pkce_verifier)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let (auth_url, csrf_token, nonce) = auth_req.url();
|
||||||
|
|
||||||
Ok(SsoLoginRequest {
|
Ok(SsoLoginRequest {
|
||||||
auth_url,
|
auth_url,
|
||||||
|
|
Loading…
Add table
Reference in a new issue