OAuth REST API

This commit is contained in:
mdecimus 2024-02-20 18:44:14 +01:00
parent 8027f135bc
commit 75bb02d13a
6 changed files with 125 additions and 51 deletions

View file

@ -21,6 +21,8 @@
* for more details.
*/
use std::sync::Arc;
use directory::{
backend::internal::{lookup::DirectoryStore, manage::ManageDirectory, PrincipalUpdate},
DirectoryError, ManagementError, Principal, QueryBy, Type,
@ -31,7 +33,11 @@ use jmap_proto::error::request::RequestError;
use serde_json::json;
use utils::config::ConfigKey;
use crate::{services::housekeeper, JMAP};
use crate::{
auth::{oauth::OAuthCodeRequest, AccessToken},
services::housekeeper,
JMAP,
};
use super::{http::ToHttpResponse, HttpRequest, JsonResponse};
@ -53,10 +59,11 @@ pub struct PrincipalResponse {
}
impl JMAP {
pub async fn handle_manage_request(
pub async fn handle_api_manage_request(
&self,
req: &HttpRequest,
body: Option<Vec<u8>>,
access_token: Arc<AccessToken>,
) -> hyper::Response<BoxBody<Bytes, hyper::Error>> {
let mut path = req.uri().path().split('/');
path.next();
@ -423,6 +430,7 @@ impl JMAP {
.into_http_response()
}
}
("oauth", _, _) => self.handle_api_request(req, body, access_token).await,
(path_1 @ ("queue" | "report"), Some(path_2), &Method::GET) => {
self.smtp
.handle_manage_request(req.uri(), req.method(), path_1, path_2)
@ -431,6 +439,38 @@ impl JMAP {
_ => RequestError::not_found().into_http_response(),
}
}
pub async fn handle_api_request(
&self,
req: &HttpRequest,
body: Option<Vec<u8>>,
access_token: Arc<AccessToken>,
) -> hyper::Response<BoxBody<Bytes, hyper::Error>> {
let mut path = req.uri().path().split('/');
path.next();
path.next();
match (path.next().unwrap_or(""), path.next(), req.method()) {
("oauth", Some("code"), &Method::POST) => {
if let Some(request) =
body.and_then(|body| serde_json::from_slice::<OAuthCodeRequest>(&body).ok())
{
JsonResponse::new(json!({
"data": self.issue_client_code(&access_token, request.client_id, request.redirect_uri),
}))
.into_http_response()
} else {
RequestError::blank(
StatusCode::BAD_REQUEST.as_u16(),
"Invalid parameters",
"Failed to deserialize modify request",
)
.into_http_response()
}
}
_ => RequestError::unauthorized().into_http_response(),
}
}
}
fn map_directory_error(err: DirectoryError) -> hyper::Response<BoxBody<Bytes, hyper::Error>> {

View file

@ -268,16 +268,25 @@ pub async fn parse_jmap_request(
}
}
"api" => {
// Make sure the user is a superuser
let body = match jmap.authenticate_headers(&req, remote_ip).await {
Ok(Some((_, access_token))) if access_token.is_super_user() => {
fetch_body(&mut req, 8192, &access_token).await
}
Ok(_) => return RequestError::unauthorized().into_http_response(),
Err(err) => return err.into_http_response(),
};
// Allow CORS preflight requests
if req.method() == Method::OPTIONS {
return ().into_http_response();
}
return jmap.handle_manage_request(&req, body).await;
// Make sure the user is a superuser
return match jmap.authenticate_headers(&req, remote_ip).await {
Ok(Some((_, access_token))) => {
let body = fetch_body(&mut req, 8192, &access_token).await;
if access_token.is_super_user() {
jmap.handle_api_manage_request(&req, body, access_token)
.await
} else {
jmap.handle_api_request(&req, body, access_token).await
}
}
Ok(None) => RequestError::unauthorized().into_http_response(),
Err(err) => err.into_http_response(),
};
}
_ => (),
}

View file

@ -73,6 +73,7 @@ pub struct OAuth {
pub metadata: String,
}
#[derive(Debug)]
pub struct OAuthCode {
pub status: AtomicU32,
pub account_id: AtomicU32,
@ -136,18 +137,19 @@ pub struct TokenRequest {
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(untagged)]
pub enum TokenResponse {
Granted {
access_token: String,
token_type: String,
expires_in: u64,
#[serde(skip_serializing_if = "Option::is_none")]
refresh_token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
scope: Option<String>,
},
Error {
error: ErrorType,
},
Granted(OAuthResponse),
Error { error: ErrorType },
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OAuthResponse {
access_token: String,
token_type: String,
expires_in: u64,
#[serde(skip_serializing_if = "Option::is_none")]
refresh_token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
scope: Option<String>,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
@ -203,6 +205,12 @@ impl OAuthMetadata {
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct OAuthCodeRequest {
pub client_id: String,
pub redirect_uri: Option<String>,
}
impl TokenResponse {
pub fn error(error: ErrorType) -> Self {
TokenResponse::Error { error }

View file

@ -43,8 +43,8 @@ use crate::{
};
use super::{
ErrorType, FormData, TokenResponse, CLIENT_ID_MAX_LEN, MAX_POST_LEN, RANDOM_CODE_LEN,
STATUS_AUTHORIZED, STATUS_PENDING, STATUS_TOKEN_ISSUED,
ErrorType, FormData, OAuthResponse, TokenResponse, CLIENT_ID_MAX_LEN, MAX_POST_LEN,
RANDOM_CODE_LEN, STATUS_AUTHORIZED, STATUS_PENDING, STATUS_TOKEN_ISSUED,
};
impl JMAP {
@ -83,6 +83,7 @@ impl JMAP {
true,
)
.await
.map(TokenResponse::Granted)
.unwrap_or_else(|err| {
tracing::error!("Failed to generate OAuth token: {}", err);
TokenResponse::error(ErrorType::InvalidRequest)
@ -122,6 +123,7 @@ impl JMAP {
true,
)
.await
.map(TokenResponse::Granted)
.unwrap_or_else(|err| {
tracing::error!("Failed to generate OAuth token: {}", err);
TokenResponse::error(ErrorType::InvalidRequest)
@ -153,6 +155,7 @@ impl JMAP {
time_left <= self.config.oauth_expiry_refresh_token_renew,
)
.await
.map(TokenResponse::Granted)
.unwrap_or_else(|err| {
tracing::debug!("Failed to refresh OAuth token: {}", err);
TokenResponse::error(ErrorType::InvalidGrant)
@ -174,12 +177,12 @@ impl JMAP {
.into_http_response()
}
async fn issue_token(
pub async fn issue_token(
&self,
account_id: u32,
client_id: &str,
with_refresh_token: bool,
) -> Result<TokenResponse, &'static str> {
) -> Result<OAuthResponse, &'static str> {
let password_hash = self
.directory
.query(QueryBy::Id(account_id), false)
@ -191,7 +194,7 @@ impl JMAP {
.next()
.ok_or("Failed to obtain password hash")?;
Ok(TokenResponse::Granted {
Ok(OAuthResponse {
access_token: self.encode_access_token(
"access_token",
account_id,
@ -297,8 +300,7 @@ impl JMAP {
return Err("Token expired.");
}
// Optain password hash
// Obtain password hash
let password_hash = self
.directory
.query(QueryBy::Id(account_id), false)

View file

@ -39,6 +39,7 @@ use utils::map::ttl_dashmap::TtlMap;
use crate::{
api::{http::ToHttpResponse, HtmlResponse, HttpRequest, HttpResponse},
auth::AccessToken,
JMAP,
};
@ -108,6 +109,33 @@ impl JMAP {
HtmlResponse::new(response).into_http_response()
}
pub fn issue_client_code(
&self,
access_token: &AccessToken,
client_id: String,
redirect_uri: Option<String>,
) -> String {
// Generate client code
let client_code = thread_rng()
.sample_iter(Alphanumeric)
.take(DEVICE_CODE_LEN)
.map(char::from)
.collect::<String>();
// Add client code
self.oauth_codes.insert_with_ttl(
client_code.clone(),
Arc::new(OAuthCode {
status: STATUS_AUTHORIZED.into(),
account_id: access_token.primary_id().into(),
client_id,
redirect_uri,
}),
Instant::now() + Duration::from_secs(self.config.oauth_expiry_auth_code),
);
client_code
}
// Handles POST request from the code authorization form
pub async fn handle_user_code_auth_post(
&self,
@ -141,30 +169,17 @@ impl JMAP {
if let AuthResult::Success(access_token) =
self.authenticate_plain(email, password, remote_addr).await
{
// Generate client code
let client_code = thread_rng()
.sample_iter(Alphanumeric)
.take(DEVICE_CODE_LEN)
.map(char::from)
.collect::<String>();
// Add client code
self.oauth_codes.insert_with_ttl(
client_code.clone(),
Arc::new(OAuthCode {
status: STATUS_AUTHORIZED.into(),
account_id: access_token.primary_id().into(),
client_id: code_req
auth_code = self
.issue_client_code(
&access_token,
code_req
.get("client_id")
.map(|s| s.as_str())
.unwrap_or_default()
.to_string(),
redirect_uri: code_req.get("redirect_uri").cloned(),
}),
Instant::now() + Duration::from_secs(self.config.oauth_expiry_auth_code),
);
auth_code = client_code.into();
code_req.get("redirect_uri").cloned(),
)
.into();
}
}

View file

@ -45,4 +45,4 @@ allow-lookups = true
[jmap.http]
#headers = ["Access-Control-Allow-Origin: *",
# "Access-Control-Allow-Methods: POST, GET, HEAD, OPTIONS",
# "Access-Control-Allow-Headers: *"]
# "Access-Control-Allow-Headers: Authorization, Content-Type, Accept, X-Requested-With"]