RFC7591 OAuth dynamic client registration + OpenID Connect Dynamic Client Registration (closes #136 closes #4)
Some checks are pending
trivy / Check (push) Waiting to run

This commit is contained in:
mdecimus 2024-10-01 10:35:35 +02:00
parent 6a5f963b43
commit 200d8d7c45
22 changed files with 619 additions and 108 deletions

View file

@ -36,6 +36,9 @@ pub struct OAuthConfig {
pub oauth_expiry_refresh_token_renew: u64,
pub oauth_max_auth_attempts: u32,
pub allow_anonymous_client_registration: bool,
pub require_client_authentication: bool,
pub oidc_expiry_id_token: u64,
pub oidc_signing_secret: Secret,
pub oidc_signature_algorithm: SignatureAlgorithm,
@ -179,6 +182,12 @@ impl OAuthConfig {
.property_or_default::<Duration>("oauth.oidc.expiry.id-token", "15m")
.unwrap_or_else(|| Duration::from_secs(15 * 60))
.as_secs(),
allow_anonymous_client_registration: config
.property_or_default("oauth.client-registration.anonymous", "false")
.unwrap_or(false),
require_client_authentication: config
.property_or_default("oauth.client-registration.required", "false")
.unwrap_or(true),
oidc_signing_secret,
oidc_signature_algorithm,
oidc_jwks,
@ -197,6 +206,8 @@ impl Default for OAuthConfig {
oauth_expiry_refresh_token_renew: Default::default(),
oauth_max_auth_attempts: Default::default(),
oidc_expiry_id_token: Default::default(),
allow_anonymous_client_registration: Default::default(),
require_client_authentication: Default::default(),
oidc_signing_secret: Secret::Bytes("secret".to_string().into_bytes()),
oidc_signature_algorithm: SignatureAlgorithm::HS256,
oidc_jwks: Resource {

View file

@ -8,6 +8,7 @@ pub mod config;
pub mod crypto;
pub mod introspect;
pub mod oidc;
pub mod registration;
pub mod token;
pub const DEVICE_CODE_LEN: usize = 40;

View file

@ -0,0 +1,181 @@
/*
* SPDX-FileCopyrightText: 2020 Stalwart Labs Ltd <hello@stalw.art>
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL
*/
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Serialize, Deserialize, Debug, Default)]
#[serde(rename_all = "snake_case")]
pub struct ClientRegistrationRequest {
pub redirect_uris: Vec<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Vec::is_empty")]
pub response_types: Vec<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Vec::is_empty")]
pub grant_types: Vec<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub application_type: Option<ApplicationType>,
#[serde(default)]
#[serde(skip_serializing_if = "Vec::is_empty")]
pub contacts: Vec<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub client_name: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub logo_uri: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub client_uri: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub policy_uri: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub tos_uri: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub jwks_uri: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub jwks: Option<serde_json::Value>, // Using serde_json::Value for flexibility
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub sector_identifier_uri: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub subject_type: Option<SubjectType>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub id_token_signed_response_alg: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub id_token_encrypted_response_alg: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub id_token_encrypted_response_enc: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub userinfo_signed_response_alg: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub userinfo_encrypted_response_alg: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub userinfo_encrypted_response_enc: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub request_object_signing_alg: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub request_object_encryption_alg: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub request_object_encryption_enc: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub token_endpoint_auth_method: Option<TokenEndpointAuthMethod>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub token_endpoint_auth_signing_alg: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub default_max_age: Option<u64>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub require_auth_time: Option<bool>,
#[serde(default)]
#[serde(skip_serializing_if = "Vec::is_empty")]
pub default_acr_values: Vec<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub initiate_login_uri: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Vec::is_empty")]
pub request_uris: Vec<String>,
#[serde(flatten)]
#[serde(skip_serializing_if = "HashMap::is_empty")]
pub additional_fields: HashMap<String, serde_json::Value>,
}
#[derive(Serialize, Deserialize, Debug, Default)]
#[serde(rename_all = "snake_case")]
pub struct ClientRegistrationResponse {
// Required fields
pub client_id: String,
// Optional fields specific to the response
#[serde(skip_serializing_if = "Option::is_none")]
pub client_secret: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub registration_access_token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub registration_client_uri: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub client_id_issued_at: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub client_secret_expires_at: Option<u64>,
// Echo back the request
#[serde(flatten)]
pub request: ClientRegistrationRequest,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ApplicationType {
Web,
Native,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum SubjectType {
Pairwise,
Public,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "snake_case")]
pub enum TokenEndpointAuthMethod {
ClientSecretPost,
ClientSecretBasic,
ClientSecretJwt,
PrivateKeyJwt,
None,
}

View file

@ -15,7 +15,9 @@ use store::{
};
use trc::AddContext;
use crate::{Permission, Principal, QueryBy, Type, ROLE_ADMIN, ROLE_TENANT_ADMIN, ROLE_USER};
use crate::{
Permission, Principal, QueryBy, Type, MAX_TYPE_ID, ROLE_ADMIN, ROLE_TENANT_ADMIN, ROLE_USER,
};
use super::{
lookup::DirectoryStore, PrincipalAction, PrincipalField, PrincipalInfo, PrincipalUpdate,
@ -271,16 +273,7 @@ impl ManageDirectory for Store {
principal.set(PrincipalField::Tenant, tenant_id);
if matches!(
principal.typ,
Type::Individual
| Type::Group
| Type::List
| Type::Role
| Type::Location
| Type::Resource
| Type::Other
) {
if !matches!(principal.typ, Type::Tenant | Type::Domain) {
if let Some(domain) = name.split('@').nth(1) {
if self
.get_principal_info(domain)
@ -513,6 +506,7 @@ impl ManageDirectory for Store {
Type::Other,
Type::Location,
Type::Domain,
Type::ApiKey,
],
&[PrincipalField::Name],
0,
@ -771,7 +765,12 @@ impl ManageDirectory for Store {
Type::Other,
][..],
Type::List => &[Type::Individual, Type::Group][..],
Type::Other | Type::Domain | Type::Tenant | Type::Individual => &[][..],
Type::Other
| Type::Domain
| Type::Tenant
| Type::Individual
| Type::ApiKey
| Type::OauthClient => &[][..],
Type::Role => &[Type::Role][..],
};
let mut valid_domains = AHashSet::new();
@ -784,16 +783,7 @@ impl ManageDirectory for Store {
let new_name = new_name.to_lowercase();
if principal.inner.name() != new_name {
if tenant_id.is_some()
&& matches!(
principal.inner.typ,
Type::Individual
| Type::Group
| Type::List
| Type::Role
| Type::Location
| Type::Resource
| Type::Other
)
&& !matches!(principal.inner.typ, Type::Tenant | Type::Domain)
{
if let Some(domain) = new_name.split('@').nth(1) {
if self
@ -978,7 +968,7 @@ impl ManageDirectory for Store {
PrincipalField::Quota,
PrincipalValue::IntegerList(quotas),
) if matches!(principal.inner.typ, Type::Tenant)
&& quotas.len() <= (Type::Role as usize + 2) =>
&& quotas.len() <= (MAX_TYPE_ID + 2) =>
{
principal.inner.set(PrincipalField::Quota, quotas);
}

View file

@ -408,6 +408,7 @@ pub enum PrincipalField {
EnabledPermissions,
DisabledPermissions,
Picture,
Urls,
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
@ -486,6 +487,7 @@ impl PrincipalField {
PrincipalField::DisabledPermissions => 12,
PrincipalField::UsedQuota => 13,
PrincipalField::Picture => 14,
PrincipalField::Urls => 15,
}
}
@ -506,6 +508,7 @@ impl PrincipalField {
12 => Some(PrincipalField::DisabledPermissions),
13 => Some(PrincipalField::UsedQuota),
14 => Some(PrincipalField::Picture),
15 => Some(PrincipalField::Urls),
_ => None,
}
}
@ -527,6 +530,7 @@ impl PrincipalField {
PrincipalField::EnabledPermissions => "enabledPermissions",
PrincipalField::DisabledPermissions => "disabledPermissions",
PrincipalField::Picture => "picture",
PrincipalField::Urls => "urls",
}
}
@ -547,6 +551,7 @@ impl PrincipalField {
"enabledPermissions" => Some(PrincipalField::EnabledPermissions),
"disabledPermissions" => Some(PrincipalField::DisabledPermissions),
"picture" => Some(PrincipalField::Picture),
"urls" => Some(PrincipalField::Urls),
_ => None,
}
}

View file

@ -183,6 +183,18 @@ impl Permission {
Permission::SieveRenameScript => "Rename Sieve scripts",
Permission::SieveCheckScript => "Validate Sieve scripts",
Permission::SieveHaveSpace => "Check available space for Sieve scripts",
Permission::OauthClientRegistration => "Register OAuth clients",
Permission::OauthClientOverride => "Override OAuth client settings",
Permission::ApiKeyList => "View API keys",
Permission::ApiKeyGet => "Retrieve specific API keys",
Permission::ApiKeyCreate => "Create new API keys",
Permission::ApiKeyUpdate => "Modify API keys",
Permission::ApiKeyDelete => "Remove API keys",
Permission::OauthClientList => "View OAuth clients",
Permission::OauthClientGet => "Retrieve specific OAuth clients",
Permission::OauthClientCreate => "Create new OAuth clients",
Permission::OauthClientUpdate => "Modify OAuth clients",
Permission::OauthClientDelete => "Remove OAuth clients",
}
}
}

View file

@ -591,6 +591,8 @@ impl Type {
Self::Tenant => "tenant",
Self::Role => "role",
Self::Domain => "domain",
Self::ApiKey => "api-key",
Self::OauthClient => "oauth-client",
}
}
@ -605,6 +607,8 @@ impl Type {
Self::Other => "Other",
Self::Role => "Role",
Self::Domain => "Domain",
Self::ApiKey => "API Key",
Self::OauthClient => "OAuth Client",
}
}
@ -619,6 +623,8 @@ impl Type {
"superuser" => Some(Type::Individual), // legacy
"role" => Some(Type::Role),
"domain" => Some(Type::Domain),
"api-key" => Some(Type::ApiKey),
"oauth-client" => Some(Type::OauthClient),
_ => None,
}
}
@ -635,6 +641,8 @@ impl Type {
7 => Type::Domain,
8 => Type::Tenant,
9 => Type::Role,
10 => Type::ApiKey,
11 => Type::OauthClient,
_ => Type::Other,
}
}
@ -835,18 +843,17 @@ impl<'de> serde::Deserialize<'de> for Principal {
| PrincipalField::Roles
| PrincipalField::Lists
| PrincipalField::EnabledPermissions
| PrincipalField::DisabledPermissions => {
match map.next_value::<StringOrMany>()? {
StringOrMany::One(v) => PrincipalValue::StringList(vec![v]),
StringOrMany::Many(v) => {
if !v.is_empty() {
PrincipalValue::StringList(v)
} else {
continue;
}
| PrincipalField::DisabledPermissions
| PrincipalField::Urls => match map.next_value::<StringOrMany>()? {
StringOrMany::One(v) => PrincipalValue::StringList(vec![v]),
StringOrMany::Many(v) => {
if !v.is_empty() {
PrincipalValue::StringList(v)
} else {
continue;
}
}
}
},
PrincipalField::UsedQuota => {
// consume and ignore
map.next_value::<IgnoredAny>()?;

View file

@ -52,8 +52,12 @@ pub enum Type {
Domain = 7,
Tenant = 8,
Role = 9,
ApiKey = 10,
OauthClient = 11,
}
pub const MAX_TYPE_ID: usize = 11;
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize, EnumMethods,
)]
@ -240,6 +244,24 @@ pub enum Permission {
SieveRenameScript,
SieveCheckScript,
SieveHaveSpace,
// API keys
ApiKeyList,
ApiKeyGet,
ApiKeyCreate,
ApiKeyUpdate,
ApiKeyDelete,
// OAuth clients
OauthClientList,
OauthClientGet,
OauthClientCreate,
OauthClientUpdate,
OauthClientDelete,
// OAuth client registration
OauthClientRegistration,
OauthClientOverride,
// WARNING: add new ids at the end (TODO: use static ids)
}

View file

@ -37,7 +37,10 @@ use crate::{
api::management::enterprise::telemetry::TelemetryApi,
auth::{
authenticate::{Authenticator, HttpHeaders},
oauth::{auth::OAuthApiHandler, openid::OpenIdHandler, token::TokenHandler, FormData},
oauth::{
auth::OAuthApiHandler, openid::OpenIdHandler, registration::ClientRegistrationHandler,
token::TokenHandler, FormData,
},
rate_limit::RateLimiter,
},
blob::{download::BlobDownload, upload::BlobUpload, DownloadResponse, UploadResponse},
@ -99,7 +102,7 @@ impl ParseHttp for Server {
("", &Method::POST) => {
// Authenticate request
let (_in_flight, access_token) =
self.authenticate_headers(&req, &session).await?;
self.authenticate_headers(&req, &session, false).await?;
let request = fetch_body(
&mut req,
@ -128,7 +131,7 @@ impl ParseHttp for Server {
("download", &Method::GET) => {
// Authenticate request
let (_in_flight, access_token) =
self.authenticate_headers(&req, &session).await?;
self.authenticate_headers(&req, &session, false).await?;
if let (Some(_), Some(blob_id), Some(name)) = (
path.next().and_then(|p| Id::from_bytes(p.as_bytes())),
@ -157,7 +160,7 @@ impl ParseHttp for Server {
("upload", &Method::POST) => {
// Authenticate request
let (_in_flight, access_token) =
self.authenticate_headers(&req, &session).await?;
self.authenticate_headers(&req, &session, false).await?;
if let Some(account_id) =
path.next().and_then(|p| Id::from_bytes(p.as_bytes()))
@ -192,14 +195,14 @@ impl ParseHttp for Server {
("eventsource", &Method::GET) => {
// Authenticate request
let (_in_flight, access_token) =
self.authenticate_headers(&req, &session).await?;
self.authenticate_headers(&req, &session, false).await?;
return self.handle_event_source(req, access_token).await;
}
("ws", &Method::GET) => {
// Authenticate request
let (_in_flight, access_token) =
self.authenticate_headers(&req, &session).await?;
self.authenticate_headers(&req, &session, false).await?;
return self
.upgrade_websocket_connection(req, access_token, session)
@ -215,7 +218,7 @@ impl ParseHttp for Server {
("jmap", &Method::GET) => {
// Authenticate request
let (_in_flight, access_token) =
self.authenticate_headers(&req, &session).await?;
self.authenticate_headers(&req, &session, false).await?;
return self
.handle_session_resource(ctx.resolve_response_url(self).await, access_token)
@ -286,7 +289,7 @@ impl ParseHttp for Server {
("introspect", &Method::POST) => {
// Authenticate request
let (_in_flight, access_token) =
self.authenticate_headers(&req, &session).await?;
self.authenticate_headers(&req, &session, false).await?;
return self
.handle_token_introspect(&mut req, &access_token, session.session_id)
@ -295,10 +298,15 @@ impl ParseHttp for Server {
("userinfo", &Method::GET) => {
// Authenticate request
let (_in_flight, access_token) =
self.authenticate_headers(&req, &session).await?;
self.authenticate_headers(&req, &session, false).await?;
return self.handle_userinfo_request(&access_token).await;
}
("register", &Method::POST) => {
return self
.handle_oauth_registration_request(&mut req, session)
.await;
}
("jwks.json", &Method::GET) => {
// Limit anonymous requests
self.is_anonymous_allowed(&session.remote_ip).await?;
@ -317,11 +325,10 @@ impl ParseHttp for Server {
}
// Authenticate user
match self.authenticate_headers(&req, &session).await {
match self.authenticate_headers(&req, &session, true).await {
Ok((_, access_token)) => {
let body = fetch_body(&mut req, 1024 * 1024, session.session_id).await;
return self
.handle_api_manage_request(&req, body, access_token, &session)
.handle_api_manage_request(&mut req, access_token, &session)
.await;
}
Err(err) => {

View file

@ -39,7 +39,10 @@ use stores::ManageStore;
use crate::{auth::oauth::auth::OAuthApiHandler, email::crypto::CryptoHandler};
use super::{http::HttpSessionData, HttpRequest, HttpResponse};
use super::{
http::{fetch_body, HttpSessionData},
HttpRequest, HttpResponse,
};
use std::future::Future;
#[derive(Serialize)]
@ -69,8 +72,7 @@ pub enum ManagementApiError<'x> {
pub trait ManagementApi: Sync + Send {
fn handle_api_manage_request(
&self,
req: &HttpRequest,
body: Option<Vec<u8>>,
req: &mut HttpRequest,
access_token: Arc<AccessToken>,
session: &HttpSessionData,
) -> impl Future<Output = trc::Result<HttpResponse>> + Send;
@ -80,11 +82,11 @@ impl ManagementApi for Server {
#[allow(unused_variables)]
async fn handle_api_manage_request(
&self,
req: &HttpRequest,
body: Option<Vec<u8>>,
req: &mut HttpRequest,
access_token: Arc<AccessToken>,
session: &HttpSessionData,
) -> trc::Result<HttpResponse> {
let body = fetch_body(req, 1024 * 1024, session.session_id).await;
let path = req.uri().path().split('/').skip(2).collect::<Vec<_>>();
match path.first().copied().unwrap_or_default() {

View file

@ -95,6 +95,8 @@ impl PrincipalManager for Server {
Type::Domain => Permission::DomainCreate,
Type::Tenant => Permission::TenantCreate,
Type::Role => Permission::RoleCreate,
Type::ApiKey => Permission::ApiKeyCreate,
Type::OauthClient => Permission::OauthClientCreate,
Type::Resource | Type::Location | Type::Other => Permission::PrincipalCreate,
})?;
@ -175,6 +177,8 @@ impl PrincipalManager for Server {
Type::Tenant,
Type::Role,
Type::Other,
Type::ApiKey,
Type::OauthClient,
]
};
for typ in validate_types {
@ -185,6 +189,8 @@ impl PrincipalManager for Server {
Type::Domain => Permission::DomainList,
Type::Tenant => Permission::TenantList,
Type::Role => Permission::RoleList,
Type::ApiKey => Permission::ApiKeyList,
Type::OauthClient => Permission::OauthClientList,
Type::Resource | Type::Location | Type::Other => Permission::PrincipalList,
})?;
}
@ -266,6 +272,8 @@ impl PrincipalManager for Server {
Type::Domain => Permission::DomainGet,
Type::Tenant => Permission::TenantGet,
Type::Role => Permission::RoleGet,
Type::ApiKey => Permission::ApiKeyGet,
Type::OauthClient => Permission::OauthClientGet,
Type::Resource | Type::Location | Type::Other => {
Permission::PrincipalGet
}
@ -301,6 +309,8 @@ impl PrincipalManager for Server {
Type::Domain => Permission::DomainDelete,
Type::Tenant => Permission::TenantDelete,
Type::Role => Permission::RoleDelete,
Type::ApiKey => Permission::ApiKeyDelete,
Type::OauthClient => Permission::OauthClientDelete,
Type::Resource | Type::Location | Type::Other => {
Permission::PrincipalDelete
}
@ -347,6 +357,8 @@ impl PrincipalManager for Server {
Type::Domain => Permission::DomainUpdate,
Type::Tenant => Permission::TenantUpdate,
Type::Role => Permission::RoleUpdate,
Type::ApiKey => Permission::ApiKeyUpdate,
Type::OauthClient => Permission::OauthClientUpdate,
Type::Resource | Type::Location | Type::Other => {
Permission::PrincipalUpdate
}
@ -382,7 +394,8 @@ impl PrincipalManager for Server {
| PrincipalField::Picture
| PrincipalField::MemberOf
| PrincipalField::Members
| PrincipalField::Lists => (),
| PrincipalField::Lists
| PrincipalField::Urls => (),
PrincipalField::Tenant => {
// Tenants are not allowed to change their tenantId
if access_token.tenant.is_some() {

View file

@ -24,6 +24,7 @@ pub trait Authenticator: Sync + Send {
&self,
req: &HttpRequest,
session: &HttpSessionData,
allow_api_access: bool,
) -> impl Future<Output = trc::Result<(InFlight, Arc<AccessToken>)>> + Send;
}
@ -32,6 +33,7 @@ impl Authenticator for Server {
&self,
req: &HttpRequest,
session: &HttpSessionData,
allow_api_access: bool,
) -> trc::Result<(InFlight, Arc<AccessToken>)> {
if let Some((mechanism, token)) = req.authorization() {
let access_token =
@ -43,29 +45,24 @@ impl Authenticator for Server {
self.is_auth_allowed_soft(&session.remote_ip).await?;
// Decode the base64 encoded credentials
if let Some((username, secret)) = base64_decode(token.as_bytes())
.and_then(|token| String::from_utf8(token).ok())
.and_then(|token| {
token.split_once(':').map(|(login, secret)| {
(login.trim().to_lowercase(), secret.to_string())
})
})
{
Credentials::Plain { username, secret }
} else {
return Err(trc::AuthEvent::Error
decode_plain_auth(token).ok_or_else(|| {
trc::AuthEvent::Error
.into_err()
.details("Failed to decode Basic auth request.")
.id(token.to_string())
.caused_by(trc::location!()));
}
.caused_by(trc::location!())
})?
} else if mechanism.eq_ignore_ascii_case("bearer") {
// Enforce anonymous rate limit
self.is_anonymous_allowed(&session.remote_ip).await?;
Credentials::OAuthBearer {
token: token.to_string(),
}
decode_bearer_token(token, allow_api_access).ok_or_else(|| {
trc::AuthEvent::Error
.into_err()
.details("Failed to decode Bearer token.")
.id(token.to_string())
.caused_by(trc::location!())
})?
} else {
// Enforce anonymous rate limit
self.is_anonymous_allowed(&session.remote_ip).await?;
@ -139,3 +136,28 @@ impl HttpHeaders for HttpRequest {
})
}
}
fn decode_plain_auth(token: &str) -> Option<Credentials<String>> {
base64_decode(token.as_bytes())
.and_then(|token| String::from_utf8(token).ok())
.and_then(|token| {
token
.split_once(':')
.map(|(login, secret)| Credentials::Plain {
username: login.trim().to_lowercase(),
secret: secret.to_string(),
})
})
}
fn decode_bearer_token(token: &str, allow_api_access: bool) -> Option<Credentials<String>> {
if allow_api_access {
if let Some(token) = token.strip_prefix("api_") {
return decode_plain_auth(token);
}
}
Some(Credentials::OAuthBearer {
token: token.to_string(),
})
}

View file

@ -39,6 +39,7 @@ pub struct OAuthMetadata {
pub token_endpoint: String,
pub authorization_endpoint: String,
pub device_authorization_endpoint: String,
pub registration_endpoint: String,
pub introspection_endpoint: String,
pub grant_types_supported: Vec<String>,
pub response_types_supported: Vec<String>,
@ -191,7 +192,7 @@ impl OAuthApiHandler for Server {
let client_id = FormData::from_request(req, MAX_POST_LEN, session.session_id)
.await?
.remove("client_id")
.filter(|client_id| client_id.len() < CLIENT_ID_MAX_LEN)
.filter(|client_id| client_id.len() <= CLIENT_ID_MAX_LEN)
.ok_or_else(|| {
trc::ResourceEvent::BadParameters
.into_err()
@ -277,12 +278,14 @@ impl OAuthApiHandler for Server {
Ok(JsonResponse::new(OAuthMetadata {
authorization_endpoint: format!("{base_url}/authorize/code",),
token_endpoint: format!("{base_url}/auth/token"),
device_authorization_endpoint: format!("{base_url}/auth/device"),
introspection_endpoint: format!("{base_url}/auth/introspect"),
registration_endpoint: format!("{base_url}/auth/register"),
grant_types_supported: vec![
"authorization_code".to_string(),
"implicit".to_string(),
"urn:ietf:params:oauth:grant-type:device_code".to_string(),
],
device_authorization_endpoint: format!("{base_url}/auth/device"),
response_types_supported: vec![
"code".to_string(),
"id_token".to_string(),
@ -290,7 +293,6 @@ impl OAuthApiHandler for Server {
"id_token token".to_string(),
],
scopes_supported: vec!["openid".to_string(), "offline_access".to_string()],
introspection_endpoint: format!("{base_url}/auth/introspect"),
issuer: base_url,
})
.into_http_response())

View file

@ -12,6 +12,7 @@ use crate::api::{http::fetch_body, HttpRequest};
pub mod auth;
pub mod openid;
pub mod registration;
pub mod token;
#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]

View file

@ -0,0 +1,156 @@
/*
* SPDX-FileCopyrightText: 2020 Stalwart Labs Ltd <hello@stalw.art>
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL
*/
use std::future::Future;
use common::{
auth::oauth::registration::{ClientRegistrationRequest, ClientRegistrationResponse},
Server,
};
use directory::{
backend::internal::{lookup::DirectoryStore, manage::ManageDirectory, PrincipalField},
Permission, Principal, QueryBy, Type,
};
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use trc::{AddContext, AuthEvent};
use crate::{
api::{
http::{fetch_body, HttpSessionData, ToHttpResponse},
HttpRequest, HttpResponse, JsonResponse,
},
auth::{authenticate::Authenticator, rate_limit::RateLimiter},
};
use super::ErrorType;
pub trait ClientRegistrationHandler: Sync + Send {
fn handle_oauth_registration_request(
&self,
req: &mut HttpRequest,
session: HttpSessionData,
) -> impl Future<Output = trc::Result<HttpResponse>> + Send;
fn validate_client_registration(
&self,
client_id: &str,
redirect_uri: Option<&str>,
account_id: u32,
) -> impl Future<Output = trc::Result<Option<ErrorType>>> + Send;
}
impl ClientRegistrationHandler for Server {
async fn handle_oauth_registration_request(
&self,
req: &mut HttpRequest,
session: HttpSessionData,
) -> trc::Result<HttpResponse> {
if !self.core.oauth.allow_anonymous_client_registration {
// Authenticate request
let (_, access_token) = self.authenticate_headers(req, &session, true).await?;
// Validate permissions
access_token.assert_has_permission(Permission::OauthClientRegistration)?;
} else {
self.is_anonymous_allowed(&session.remote_ip).await?;
}
// Parse request
let body = fetch_body(req, 20 * 1024, session.session_id).await;
let request = serde_json::from_slice::<ClientRegistrationRequest>(
body.as_deref().unwrap_or_default(),
)
.map_err(|err| {
trc::EventType::Resource(trc::ResourceEvent::BadParameters).from_json_error(err)
})?;
// Generate client ID
let client_id = thread_rng()
.sample_iter(Alphanumeric)
.take(20)
.map(|ch| char::from(ch.to_ascii_lowercase()))
.collect::<String>();
self.store()
.create_principal(
Principal::new(u32::MAX, Type::OauthClient)
.with_field(PrincipalField::Name, client_id.clone())
.with_field(PrincipalField::Urls, request.redirect_uris.clone())
.with_opt_field(PrincipalField::Description, request.client_name.clone())
.with_field(PrincipalField::Emails, request.contacts.clone())
.with_opt_field(PrincipalField::Picture, request.logo_uri.clone()),
None,
)
.await
.caused_by(trc::location!())?;
trc::event!(
Auth(AuthEvent::ClientRegistration),
Id = client_id.to_string(),
RemoteIp = session.remote_ip
);
Ok(JsonResponse::new(ClientRegistrationResponse {
client_id,
request,
..Default::default()
})
.into_http_response())
}
async fn validate_client_registration(
&self,
client_id: &str,
redirect_uri: Option<&str>,
account_id: u32,
) -> trc::Result<Option<ErrorType>> {
if !self.core.oauth.require_client_authentication {
return Ok(None);
}
// Fetch client registration
let found_registration = if let Some(client) = self
.store()
.query(QueryBy::Name(client_id), false)
.await
.caused_by(trc::location!())?
.filter(|p| p.typ() == Type::OauthClient)
{
if let Some(redirect_uri) = redirect_uri {
if client
.get_str_array(PrincipalField::Urls)
.unwrap_or_default()
.iter()
.any(|uri| uri == redirect_uri)
{
return Ok(None);
}
} else {
// Device flow does not require a redirect URI
return Ok(None);
}
true
} else {
false
};
// Check if the account is allowed to override client registration
if self
.get_cached_access_token(account_id)
.await
.caused_by(trc::location!())?
.has_permission(Permission::OauthClientOverride)
{
return Ok(None);
}
Ok(Some(if found_registration {
ErrorType::InvalidClient
} else {
ErrorType::InvalidRequest
}))
}
}

View file

@ -18,7 +18,8 @@ use crate::api::{
};
use super::{
ErrorType, FormData, OAuthCode, OAuthResponse, OAuthStatus, TokenResponse, MAX_POST_LEN,
registration::ClientRegistrationHandler, ErrorType, FormData, OAuthCode, OAuthResponse,
OAuthStatus, TokenResponse, MAX_POST_LEN,
};
pub trait TokenHandler: Sync + Send {
@ -80,23 +81,35 @@ impl TokenHandler for Server {
if client_id != oauth.client_id || redirect_uri != oauth.params {
TokenResponse::error(ErrorType::InvalidClient)
} else if oauth.status == OAuthStatus::Authorized {
// Mark this token as issued
self.core
.storage
.lookup
.key_delete(format!("oauth:{code}").into_bytes())
.await?;
// Validate client id
if let Some(error) = self
.validate_client_registration(
client_id,
redirect_uri.into(),
oauth.account_id,
)
.await?
{
TokenResponse::error(error)
} else {
// Mark this token as issued
self.core
.storage
.lookup
.key_delete(format!("oauth:{code}").into_bytes())
.await?;
// Issue token
self.issue_token(oauth.account_id, &oauth.client_id, issuer, true)
.await
.map(TokenResponse::Granted)
.map_err(|err| {
trc::AuthEvent::Error
.into_err()
.details(err)
.caused_by(trc::location!())
})?
// Issue token
self.issue_token(oauth.account_id, &oauth.client_id, issuer, true)
.await
.map(TokenResponse::Granted)
.map_err(|err| {
trc::AuthEvent::Error
.into_err()
.details(err)
.caused_by(trc::location!())
})?
}
} else {
TokenResponse::error(ErrorType::InvalidGrant)
}
@ -126,15 +139,26 @@ impl TokenHandler for Server {
} else {
match oauth.status {
OAuthStatus::Authorized => {
// Mark this token as issued
self.core
.storage
.lookup
.key_delete(format!("oauth:{device_code}").into_bytes())
.await?;
if let Some(error) = self
.validate_client_registration(client_id, None, oauth.account_id)
.await?
{
TokenResponse::error(error)
} else {
// Mark this token as issued
self.core
.storage
.lookup
.key_delete(format!("oauth:{device_code}").into_bytes())
.await?;
// Issue token
self.issue_token(oauth.account_id, &oauth.client_id, issuer, true)
// Issue token
self.issue_token(
oauth.account_id,
&oauth.client_id,
issuer,
true,
)
.await
.map(TokenResponse::Granted)
.map_err(|err| {
@ -143,6 +167,7 @@ impl TokenHandler for Server {
.details(err)
.caused_by(trc::location!())
})?
}
}
OAuthStatus::Pending => {
TokenResponse::error(ErrorType::AuthorizationPending)

View file

@ -1742,6 +1742,7 @@ impl AuthEvent {
AuthEvent::TooManyAttempts => "Too many authentication attempts",
AuthEvent::Error => "Authentication error",
AuthEvent::TokenExpired => "OAuth token expired",
AuthEvent::ClientRegistration => "OAuth Client registration",
}
}
@ -1753,6 +1754,7 @@ impl AuthEvent {
AuthEvent::TooManyAttempts => "Too many authentication attempts have been made",
AuthEvent::Error => "An error occurred with authentication",
AuthEvent::TokenExpired => "OAuth authentication token has expired",
AuthEvent::ClientRegistration => "OAuth client successfully registered",
}
}
}

View file

@ -229,7 +229,7 @@ impl EventType {
AuthEvent::MissingTotp => Level::Trace,
AuthEvent::TooManyAttempts => Level::Warn,
AuthEvent::Error => Level::Error,
AuthEvent::Success => Level::Info,
AuthEvent::Success | AuthEvent::ClientRegistration => Level::Info,
},
EventType::Config(cause) => match cause {
ConfigEvent::ParseError

View file

@ -926,6 +926,7 @@ pub enum AuthEvent {
TokenExpired,
MissingTotp,
TooManyAttempts,
ClientRegistration,
Error,
}

View file

@ -860,6 +860,7 @@ impl EventType {
EventType::Security(SecurityEvent::Unauthorized) => 552,
EventType::Limit(LimitEvent::TenantQuota) => 553,
EventType::Auth(AuthEvent::TokenExpired) => 554,
EventType::Auth(AuthEvent::ClientRegistration) => 555,
}
}
@ -1460,6 +1461,7 @@ impl EventType {
552 => Some(EventType::Security(SecurityEvent::Unauthorized)),
553 => Some(EventType::Limit(LimitEvent::TenantQuota)),
554 => Some(EventType::Auth(AuthEvent::TokenExpired)),
555 => Some(EventType::Auth(AuthEvent::ClientRegistration)),
_ => None,
}
}

View file

@ -9,7 +9,10 @@ use std::time::{Duration, Instant};
use base64::{engine::general_purpose, Engine};
use biscuit::{jwk::JWKSet, SingleOrMultiple, JWT};
use bytes::Bytes;
use common::auth::oauth::introspect::OAuthIntrospect;
use common::auth::oauth::{
introspect::OAuthIntrospect,
registration::{ClientRegistrationRequest, ClientRegistrationResponse},
};
use imap_proto::ResponseType;
use jmap::auth::oauth::{
auth::OAuthMetadata, openid::OpenIdMetadata, DeviceAuthResponse, ErrorType, OAuthCodeRequest,
@ -20,7 +23,7 @@ use jmap_client::{
mailbox::query::Filter,
};
use jmap_proto::types::id::Id;
use serde::de::DeserializeOwned;
use serde::{de::DeserializeOwned, Serialize};
use store::ahash::AHashMap;
use crate::{
@ -72,6 +75,18 @@ pub async fn test(params: &mut JMAPTest) {
get("https://127.0.0.1:8899/.well-known/openid-configuration").await;
let jwk_set: JWKSet<()> = get(&oidc_metadata.jwks_uri).await;
// Register client
let registration: ClientRegistrationResponse = post_json(
&metadata.registration_endpoint,
None,
&ClientRegistrationRequest {
redirect_uris: vec!["https://localhost".to_string()],
..Default::default()
},
)
.await;
let client_id = registration.client_id;
/*println!("OAuth metadata: {:#?}", metadata);
println!("OpenID metadata: {:#?}", oidc_metadata);
println!("JWKSet: {:#?}", jwk_set);*/
@ -85,7 +100,7 @@ pub async fn test(params: &mut JMAPTest) {
.post::<OAuthCodeResponse>(
"/api/oauth",
&OAuthCodeRequest::Code {
client_id: "OAuthyMcOAuthFace".to_string(),
client_id: client_id.to_string(),
redirect_uri: "https://localhost".to_string().into(),
},
)
@ -106,7 +121,7 @@ pub async fn test(params: &mut JMAPTest) {
error: ErrorType::InvalidClient
}
);
token_params.insert("client_id".to_string(), "OAuthyMcOAuthFace".to_string());
token_params.insert("client_id".to_string(), client_id.to_string());
token_params.insert(
"redirect_uri".to_string(),
"https://some-other.url".to_string(),
@ -147,7 +162,7 @@ pub async fn test(params: &mut JMAPTest) {
assert_eq!(claims.subject, Some(john_int_id.to_string()));
assert_eq!(
claims.audience,
Some(SingleOrMultiple::Single("OAuthyMcOAuthFace".to_string()))
Some(SingleOrMultiple::Single(client_id.to_string()))
);
// Introspect token
@ -159,7 +174,7 @@ pub async fn test(params: &mut JMAPTest) {
.await;
assert_eq!(access_introspect.username.unwrap(), "jdoe@example.com");
assert_eq!(access_introspect.token_type.unwrap(), "bearer");
assert_eq!(access_introspect.client_id.unwrap(), "OAuthyMcOAuthFace");
assert_eq!(access_introspect.client_id.unwrap(), client_id);
assert!(access_introspect.active);
let refresh_introspect = post_with_auth::<OAuthIntrospect>(
&metadata.introspection_endpoint,
@ -168,7 +183,7 @@ pub async fn test(params: &mut JMAPTest) {
)
.await;
assert_eq!(refresh_introspect.username.unwrap(), "jdoe@example.com");
assert_eq!(refresh_introspect.client_id.unwrap(), "OAuthyMcOAuthFace");
assert_eq!(refresh_introspect.client_id.unwrap(), client_id);
assert!(refresh_introspect.active);
assert_eq!(
refresh_introspect.iat.unwrap(),
@ -211,14 +226,15 @@ pub async fn test(params: &mut JMAPTest) {
// ------------------------
// Request a device code
let device_code_params = AHashMap::from_iter([("client_id".to_string(), "1234".to_string())]);
let device_code_params =
AHashMap::from_iter([("client_id".to_string(), client_id.to_string())]);
let device_response: DeviceAuthResponse =
post(&metadata.device_authorization_endpoint, &device_code_params).await;
//println!("Device response: {:#?}", device_response);
// Status should be pending
let mut token_params = AHashMap::from_iter([
("client_id".to_string(), "1234".to_string()),
("client_id".to_string(), client_id.to_string()),
(
"grant_type".to_string(),
"urn:ietf:params:oauth:grant-type:device_code".to_string(),
@ -313,7 +329,7 @@ pub async fn test(params: &mut JMAPTest) {
post::<TokenResponse>(
&metadata.token_endpoint,
&AHashMap::from_iter([
("client_id".to_string(), "1234".to_string()),
("client_id".to_string(), client_id.to_string()),
("grant_type".to_string(), "refresh_token".to_string()),
("refresh_token".to_string(), token),
]),
@ -326,7 +342,7 @@ pub async fn test(params: &mut JMAPTest) {
// Refreshing the access token before expiration should not include a new refresh token
let refresh_params = AHashMap::from_iter([
("client_id".to_string(), "1234".to_string()),
("client_id".to_string(), client_id.to_string()),
("grant_type".to_string(), "refresh_token".to_string()),
("refresh_token".to_string(), refresh_token),
]);
@ -401,6 +417,35 @@ async fn post_bytes(
.unwrap()
}
async fn post_json<D: DeserializeOwned>(
url: &str,
auth_token: Option<&str>,
body: &impl Serialize,
) -> D {
let mut client = reqwest::Client::builder()
.timeout(Duration::from_millis(500))
.danger_accept_invalid_certs(true)
.build()
.unwrap_or_default()
.post(url);
if let Some(auth_token) = auth_token {
client = client.bearer_auth(auth_token);
}
serde_json::from_slice(
&client
.body(serde_json::to_string(body).unwrap().into_bytes())
.send()
.await
.unwrap()
.bytes()
.await
.unwrap(),
)
.unwrap()
}
async fn post<T: DeserializeOwned>(url: &str, params: &AHashMap<String, String>) -> T {
post_with_auth(url, None, params).await
}

View file

@ -289,6 +289,10 @@ token = "1s"
refresh-token = "3s"
refresh-token-renew = "2s"
[oauth.client-registration]
anonymous = true
required = true
[oauth.oidc]
signature-key = '''-----BEGIN PRIVATE KEY-----
MIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQDMXJI1bL3z8gaF
@ -339,7 +343,7 @@ type = "console"
level = "{LEVEL}"
multiline = false
ansi = true
disabled-events = ["network.*"]
disabled-events = ["network.*", "telemetry.webhook-error"]
[webhook."test"]
url = "http://127.0.0.1:8821/hook"