mirror of
https://github.com/stalwartlabs/mail-server.git
synced 2025-09-12 23:14:18 +08:00
WebSocket tests passing
This commit is contained in:
parent
4ff2158783
commit
86a8a5f7d5
18 changed files with 1004 additions and 331 deletions
22
Cargo.lock
generated
22
Cargo.lock
generated
|
@ -1686,7 +1686,9 @@ dependencies = [
|
|||
"sqlx",
|
||||
"store",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tracing",
|
||||
"tungstenite",
|
||||
"utils",
|
||||
]
|
||||
|
||||
|
@ -1702,6 +1704,7 @@ dependencies = [
|
|||
"maybe-async 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"parking_lot",
|
||||
"reqwest",
|
||||
"rustls 0.21.1",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
|
@ -3931,18 +3934,17 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.18.0"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd"
|
||||
checksum = "ec509ac96e9a0c43427c74f003127d953a265737636129424288d27cb5c4b12c"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"rustls 0.20.8",
|
||||
"rustls 0.21.1",
|
||||
"tokio",
|
||||
"tokio-rustls 0.23.4",
|
||||
"tokio-rustls 0.24.0",
|
||||
"tungstenite",
|
||||
"webpki",
|
||||
"webpki-roots 0.22.6",
|
||||
"webpki-roots 0.23.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -4195,18 +4197,18 @@ checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
|
|||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.18.0"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788"
|
||||
checksum = "15fba1a6d6bb030745759a9a2a588bfe8490fc8b4751a277db3a0be1c9ebbf67"
|
||||
dependencies = [
|
||||
"base64 0.13.1",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand",
|
||||
"rustls 0.20.8",
|
||||
"rustls 0.21.1",
|
||||
"sha1",
|
||||
"thiserror",
|
||||
"url",
|
||||
|
|
|
@ -3,6 +3,7 @@ pub mod echo;
|
|||
pub mod method;
|
||||
pub mod parser;
|
||||
pub mod reference;
|
||||
pub mod websocket;
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
|
@ -29,7 +30,7 @@ use crate::{
|
|||
|
||||
use self::{echo::Echo, method::MethodName};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Request {
|
||||
pub using: u32,
|
||||
pub method_calls: Vec<Call<RequestMethod>>,
|
||||
|
|
|
@ -40,152 +40,7 @@ impl Request {
|
|||
let mut parser = Parser::new(json);
|
||||
parser.next_token::<String>()?.assert(Token::DictStart)?;
|
||||
while let Some(key) = parser.next_dict_key::<u128>()? {
|
||||
match key {
|
||||
0x0067_6e69_7375 => {
|
||||
found_valid_keys = true;
|
||||
parser.next_token::<Ignore>()?.assert(Token::ArrayStart)?;
|
||||
loop {
|
||||
match parser.next_token::<Capability>()? {
|
||||
Token::String(capability) => {
|
||||
request.using |= capability as u32;
|
||||
}
|
||||
Token::Comma => (),
|
||||
Token::ArrayEnd => break,
|
||||
token => {
|
||||
return Err(token
|
||||
.error("capability", &token.to_string())
|
||||
.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
0x0073_6c6c_6143_646f_6874_656d => {
|
||||
found_valid_keys = true;
|
||||
|
||||
parser
|
||||
.next_token::<Ignore>()?
|
||||
.assert_jmap(Token::ArrayStart)?;
|
||||
loop {
|
||||
match parser.next_token::<Ignore>()? {
|
||||
Token::ArrayStart => (),
|
||||
Token::Comma => continue,
|
||||
Token::ArrayEnd => break,
|
||||
_ => {
|
||||
return Err(RequestError::not_request("Invalid JMAP request"));
|
||||
}
|
||||
};
|
||||
if request.method_calls.len() < max_calls {
|
||||
let method_name = match parser.next_token::<MethodName>() {
|
||||
Ok(Token::String(method)) => method,
|
||||
Ok(_) => {
|
||||
return Err(RequestError::not_request(
|
||||
"Invalid JMAP request",
|
||||
));
|
||||
}
|
||||
Err(Error::Method(MethodError::InvalidArguments(_))) => {
|
||||
MethodName::error()
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(err.into());
|
||||
}
|
||||
};
|
||||
parser.next_token::<Ignore>()?.assert_jmap(Token::Comma)?;
|
||||
parser.ctx = method_name.obj;
|
||||
let start_depth_array = parser.depth_array;
|
||||
let start_depth_dict = parser.depth_dict;
|
||||
|
||||
let method = match (&method_name.fnc, &method_name.obj) {
|
||||
(MethodFunction::Get, _) => {
|
||||
if method_name.obj != MethodObject::SearchSnippet {
|
||||
GetRequest::parse(&mut parser).map(RequestMethod::Get)
|
||||
} else {
|
||||
GetSearchSnippetRequest::parse(&mut parser)
|
||||
.map(RequestMethod::SearchSnippet)
|
||||
}
|
||||
}
|
||||
(MethodFunction::Query, _) => {
|
||||
QueryRequest::parse(&mut parser).map(RequestMethod::Query)
|
||||
}
|
||||
(MethodFunction::Set, _) => {
|
||||
SetRequest::parse(&mut parser).map(RequestMethod::Set)
|
||||
}
|
||||
(MethodFunction::Changes, _) => {
|
||||
ChangesRequest::parse(&mut parser)
|
||||
.map(RequestMethod::Changes)
|
||||
}
|
||||
(MethodFunction::QueryChanges, _) => {
|
||||
QueryChangesRequest::parse(&mut parser)
|
||||
.map(RequestMethod::QueryChanges)
|
||||
}
|
||||
(MethodFunction::Copy, MethodObject::Email) => {
|
||||
CopyRequest::parse(&mut parser).map(RequestMethod::Copy)
|
||||
}
|
||||
(MethodFunction::Copy, MethodObject::Blob) => {
|
||||
CopyBlobRequest::parse(&mut parser)
|
||||
.map(RequestMethod::CopyBlob)
|
||||
}
|
||||
(MethodFunction::Import, MethodObject::Email) => {
|
||||
ImportEmailRequest::parse(&mut parser)
|
||||
.map(RequestMethod::ImportEmail)
|
||||
}
|
||||
(MethodFunction::Parse, MethodObject::Email) => {
|
||||
ParseEmailRequest::parse(&mut parser)
|
||||
.map(RequestMethod::ParseEmail)
|
||||
}
|
||||
(MethodFunction::Validate, MethodObject::SieveScript) => {
|
||||
ValidateSieveScriptRequest::parse(&mut parser)
|
||||
.map(RequestMethod::ValidateScript)
|
||||
}
|
||||
(MethodFunction::Echo, MethodObject::Core) => {
|
||||
Echo::parse(&mut parser).map(RequestMethod::Echo)
|
||||
}
|
||||
_ => Err(Error::Method(MethodError::UnknownMethod(
|
||||
method_name.to_string(),
|
||||
))),
|
||||
};
|
||||
|
||||
let method = match method {
|
||||
Ok(method) => method,
|
||||
Err(Error::Method(err)) => {
|
||||
parser.skip_token(start_depth_array, start_depth_dict)?;
|
||||
RequestMethod::Error(err)
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(err.into());
|
||||
}
|
||||
};
|
||||
|
||||
parser.next_token::<Ignore>()?.assert_jmap(Token::Comma)?;
|
||||
let id = parser.next_token::<String>()?.unwrap_string("")?;
|
||||
parser
|
||||
.next_token::<Ignore>()?
|
||||
.assert_jmap(Token::ArrayEnd)?;
|
||||
request.method_calls.push(Call {
|
||||
id,
|
||||
method,
|
||||
name: method_name,
|
||||
});
|
||||
} else {
|
||||
return Err(RequestError::limit(RequestLimitError::CallsIn));
|
||||
}
|
||||
}
|
||||
}
|
||||
0x7364_4964_6574_6165_7263 => {
|
||||
found_valid_keys = true;
|
||||
let mut created_ids = HashMap::new();
|
||||
parser.next_token::<Ignore>()?.assert(Token::DictStart)?;
|
||||
while let Some(key) = parser.next_dict_key::<String>()? {
|
||||
created_ids.insert(
|
||||
key,
|
||||
parser.next_token::<Id>()?.unwrap_string("createdIds")?,
|
||||
);
|
||||
}
|
||||
request.created_ids = Some(created_ids);
|
||||
}
|
||||
_ => {
|
||||
parser.skip_token(parser.depth_array, parser.depth_dict)?;
|
||||
}
|
||||
}
|
||||
found_valid_keys |= request.parse_key(&mut parser, max_calls, key)?;
|
||||
}
|
||||
|
||||
if found_valid_keys {
|
||||
|
@ -197,6 +52,147 @@ impl Request {
|
|||
Err(RequestError::limit(RequestLimitError::Size))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn parse_key(
|
||||
&mut self,
|
||||
parser: &mut Parser,
|
||||
max_calls: usize,
|
||||
key: u128,
|
||||
) -> Result<bool, RequestError> {
|
||||
match key {
|
||||
0x0067_6e69_7375 => {
|
||||
parser.next_token::<Ignore>()?.assert(Token::ArrayStart)?;
|
||||
loop {
|
||||
match parser.next_token::<Capability>()? {
|
||||
Token::String(capability) => {
|
||||
self.using |= capability as u32;
|
||||
}
|
||||
Token::Comma => (),
|
||||
Token::ArrayEnd => break,
|
||||
token => return Err(token.error("capability", &token.to_string()).into()),
|
||||
}
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
0x0073_6c6c_6143_646f_6874_656d => {
|
||||
parser
|
||||
.next_token::<Ignore>()?
|
||||
.assert_jmap(Token::ArrayStart)?;
|
||||
loop {
|
||||
match parser.next_token::<Ignore>()? {
|
||||
Token::ArrayStart => (),
|
||||
Token::Comma => continue,
|
||||
Token::ArrayEnd => break,
|
||||
_ => {
|
||||
return Err(RequestError::not_request("Invalid JMAP request"));
|
||||
}
|
||||
};
|
||||
if self.method_calls.len() < max_calls {
|
||||
let method_name = match parser.next_token::<MethodName>() {
|
||||
Ok(Token::String(method)) => method,
|
||||
Ok(_) => {
|
||||
return Err(RequestError::not_request("Invalid JMAP request"));
|
||||
}
|
||||
Err(Error::Method(MethodError::InvalidArguments(_))) => {
|
||||
MethodName::error()
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(err.into());
|
||||
}
|
||||
};
|
||||
parser.next_token::<Ignore>()?.assert_jmap(Token::Comma)?;
|
||||
parser.ctx = method_name.obj;
|
||||
let start_depth_array = parser.depth_array;
|
||||
let start_depth_dict = parser.depth_dict;
|
||||
|
||||
let method = match (&method_name.fnc, &method_name.obj) {
|
||||
(MethodFunction::Get, _) => {
|
||||
if method_name.obj != MethodObject::SearchSnippet {
|
||||
GetRequest::parse(parser).map(RequestMethod::Get)
|
||||
} else {
|
||||
GetSearchSnippetRequest::parse(parser)
|
||||
.map(RequestMethod::SearchSnippet)
|
||||
}
|
||||
}
|
||||
(MethodFunction::Query, _) => {
|
||||
QueryRequest::parse(parser).map(RequestMethod::Query)
|
||||
}
|
||||
(MethodFunction::Set, _) => {
|
||||
SetRequest::parse(parser).map(RequestMethod::Set)
|
||||
}
|
||||
(MethodFunction::Changes, _) => {
|
||||
ChangesRequest::parse(parser).map(RequestMethod::Changes)
|
||||
}
|
||||
(MethodFunction::QueryChanges, _) => {
|
||||
QueryChangesRequest::parse(parser).map(RequestMethod::QueryChanges)
|
||||
}
|
||||
(MethodFunction::Copy, MethodObject::Email) => {
|
||||
CopyRequest::parse(parser).map(RequestMethod::Copy)
|
||||
}
|
||||
(MethodFunction::Copy, MethodObject::Blob) => {
|
||||
CopyBlobRequest::parse(parser).map(RequestMethod::CopyBlob)
|
||||
}
|
||||
(MethodFunction::Import, MethodObject::Email) => {
|
||||
ImportEmailRequest::parse(parser).map(RequestMethod::ImportEmail)
|
||||
}
|
||||
(MethodFunction::Parse, MethodObject::Email) => {
|
||||
ParseEmailRequest::parse(parser).map(RequestMethod::ParseEmail)
|
||||
}
|
||||
(MethodFunction::Validate, MethodObject::SieveScript) => {
|
||||
ValidateSieveScriptRequest::parse(parser)
|
||||
.map(RequestMethod::ValidateScript)
|
||||
}
|
||||
(MethodFunction::Echo, MethodObject::Core) => {
|
||||
Echo::parse(parser).map(RequestMethod::Echo)
|
||||
}
|
||||
_ => Err(Error::Method(MethodError::UnknownMethod(
|
||||
method_name.to_string(),
|
||||
))),
|
||||
};
|
||||
|
||||
let method = match method {
|
||||
Ok(method) => method,
|
||||
Err(Error::Method(err)) => {
|
||||
parser.skip_token(start_depth_array, start_depth_dict)?;
|
||||
RequestMethod::Error(err)
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(err.into());
|
||||
}
|
||||
};
|
||||
|
||||
parser.next_token::<Ignore>()?.assert_jmap(Token::Comma)?;
|
||||
let id = parser.next_token::<String>()?.unwrap_string("")?;
|
||||
parser
|
||||
.next_token::<Ignore>()?
|
||||
.assert_jmap(Token::ArrayEnd)?;
|
||||
self.method_calls.push(Call {
|
||||
id,
|
||||
method,
|
||||
name: method_name,
|
||||
});
|
||||
} else {
|
||||
return Err(RequestError::limit(RequestLimitError::CallsIn));
|
||||
}
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
0x7364_4964_6574_6165_7263 => {
|
||||
let mut created_ids = HashMap::new();
|
||||
parser.next_token::<Ignore>()?.assert(Token::DictStart)?;
|
||||
while let Some(key) = parser.next_dict_key::<String>()? {
|
||||
created_ids
|
||||
.insert(key, parser.next_token::<Id>()?.unwrap_string("createdIds")?);
|
||||
}
|
||||
self.created_ids = Some(created_ids);
|
||||
Ok(true)
|
||||
}
|
||||
_ => {
|
||||
parser.skip_token(parser.depth_array, parser.depth_dict)?;
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Error> for RequestError {
|
||||
|
|
236
crates/jmap-proto/src/request/websocket.rs
Normal file
236
crates/jmap-proto/src/request/websocket.rs
Normal file
|
@ -0,0 +1,236 @@
|
|||
use std::{borrow::Cow, collections::HashMap};
|
||||
|
||||
use crate::{
|
||||
error::request::{RequestError, RequestErrorType, RequestLimitError},
|
||||
parser::{json::Parser, Error, JsonObjectParser, Token},
|
||||
request::Call,
|
||||
response::{serialize::serialize_hex, Response, ResponseMethod},
|
||||
types::{id::Id, state::State, type_state::TypeState},
|
||||
};
|
||||
use utils::map::vec_map::VecMap;
|
||||
|
||||
use super::{Request, RequestProperty};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct WebSocketRequest {
|
||||
pub id: Option<String>,
|
||||
pub request: Request,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct WebSocketResponse {
|
||||
#[serde(rename = "@type")]
|
||||
_type: WebSocketResponseType,
|
||||
|
||||
#[serde(rename = "methodResponses")]
|
||||
method_responses: Vec<Call<ResponseMethod>>,
|
||||
|
||||
#[serde(rename = "sessionState")]
|
||||
#[serde(serialize_with = "serialize_hex")]
|
||||
session_state: u32,
|
||||
|
||||
#[serde(rename(deserialize = "createdIds"))]
|
||||
#[serde(skip_serializing_if = "HashMap::is_empty")]
|
||||
created_ids: HashMap<String, Id>,
|
||||
|
||||
#[serde(rename = "requestId")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
request_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, serde::Serialize)]
|
||||
pub enum WebSocketResponseType {
|
||||
Response,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, PartialEq, Eq)]
|
||||
pub struct WebSocketPushEnable {
|
||||
pub data_types: Vec<TypeState>,
|
||||
pub push_state: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum WebSocketMessage {
|
||||
Request(WebSocketRequest),
|
||||
PushEnable(WebSocketPushEnable),
|
||||
PushDisable,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, Debug)]
|
||||
pub enum WebSocketStateChangeType {
|
||||
StateChange,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, Debug)]
|
||||
pub struct WebSocketStateChange {
|
||||
#[serde(rename = "@type")]
|
||||
pub type_: WebSocketStateChangeType,
|
||||
pub changed: VecMap<Id, VecMap<TypeState, State>>,
|
||||
#[serde(rename = "pushState")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
push_state: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct WebSocketRequestError {
|
||||
#[serde(rename = "@type")]
|
||||
pub type_: WebSocketRequestErrorType,
|
||||
|
||||
#[serde(rename = "type")]
|
||||
p_type: RequestErrorType,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
limit: Option<RequestLimitError>,
|
||||
status: u16,
|
||||
detail: Cow<'static, str>,
|
||||
|
||||
#[serde(rename = "requestId")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub request_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, Debug)]
|
||||
pub enum WebSocketRequestErrorType {
|
||||
RequestError,
|
||||
}
|
||||
|
||||
enum MessageType {
|
||||
Request,
|
||||
PushEnable,
|
||||
PushDisable,
|
||||
None,
|
||||
}
|
||||
|
||||
impl WebSocketMessage {
|
||||
pub fn parse(
|
||||
json: &[u8],
|
||||
max_calls: usize,
|
||||
max_size: usize,
|
||||
) -> Result<Self, WebSocketRequestError> {
|
||||
if json.len() <= max_size {
|
||||
let mut message_type = MessageType::None;
|
||||
let mut request = WebSocketRequest {
|
||||
id: None,
|
||||
request: Request::default(),
|
||||
};
|
||||
let mut push_enable = WebSocketPushEnable::default();
|
||||
|
||||
let mut found_request_keys = false;
|
||||
let mut found_push_keys = false;
|
||||
|
||||
let mut parser = Parser::new(json);
|
||||
parser.next_token::<String>()?.assert(Token::DictStart)?;
|
||||
while let Some(key) = parser.next_dict_key::<u128>()? {
|
||||
match key {
|
||||
0x0065_7079_7440 => {
|
||||
let rt = parser
|
||||
.next_token::<RequestProperty>()?
|
||||
.unwrap_string("@type")?;
|
||||
message_type = match (rt.hash[0], rt.hash[1]) {
|
||||
(0x0074_7365_7571_6552, 0) => MessageType::Request,
|
||||
(0x616e_4568_7375_5074_656b_636f_5362_6557, 0x656c62) => {
|
||||
MessageType::PushEnable
|
||||
}
|
||||
(0x7369_4468_7375_5074_656b_636f_5362_6557, 0x656c6261) => {
|
||||
MessageType::PushDisable
|
||||
}
|
||||
_ => MessageType::None,
|
||||
};
|
||||
}
|
||||
0x0073_6570_7954_6174_6164 => {
|
||||
push_enable.data_types =
|
||||
<Option<Vec<TypeState>>>::parse(&mut parser)?.unwrap_or_default();
|
||||
found_push_keys = true;
|
||||
}
|
||||
0x0065_7461_7453_6873_7570 => {
|
||||
push_enable.push_state = parser
|
||||
.next_token::<String>()?
|
||||
.unwrap_string_or_null("pushState")?;
|
||||
found_push_keys = true;
|
||||
}
|
||||
0x6469 => {
|
||||
request.id = parser.next_token::<String>()?.unwrap_string_or_null("id")?;
|
||||
}
|
||||
_ => {
|
||||
found_request_keys |=
|
||||
request.request.parse_key(&mut parser, max_calls, key)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match message_type {
|
||||
MessageType::Request if found_request_keys => {
|
||||
Ok(WebSocketMessage::Request(request))
|
||||
}
|
||||
MessageType::PushEnable if found_push_keys => {
|
||||
Ok(WebSocketMessage::PushEnable(push_enable))
|
||||
}
|
||||
MessageType::PushDisable if !found_request_keys && !found_push_keys => {
|
||||
Ok(WebSocketMessage::PushDisable)
|
||||
}
|
||||
_ => Err(RequestError::not_request("Invalid WebSocket JMAP request").into()),
|
||||
}
|
||||
} else {
|
||||
Err(RequestError::limit(RequestLimitError::Size).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WebSocketRequestError {
|
||||
pub fn from_error(error: RequestError, request_id: Option<String>) -> Self {
|
||||
Self {
|
||||
type_: WebSocketRequestErrorType::RequestError,
|
||||
p_type: error.p_type,
|
||||
limit: error.limit,
|
||||
status: error.status,
|
||||
detail: error.detail,
|
||||
request_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_json(&self) -> String {
|
||||
serde_json::to_string(self).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RequestError> for WebSocketRequestError {
|
||||
fn from(value: RequestError) -> Self {
|
||||
Self::from_error(value, None)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Error> for WebSocketRequestError {
|
||||
fn from(value: Error) -> Self {
|
||||
RequestError::from(value).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl WebSocketResponse {
|
||||
pub fn from_response(response: Response, request_id: Option<String>) -> Self {
|
||||
Self {
|
||||
_type: WebSocketResponseType::Response,
|
||||
method_responses: response.method_responses,
|
||||
session_state: response.session_state,
|
||||
created_ids: response.created_ids,
|
||||
request_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_json(&self) -> String {
|
||||
serde_json::to_string(self).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl WebSocketStateChange {
|
||||
pub fn new(push_state: Option<String>) -> Self {
|
||||
WebSocketStateChange {
|
||||
type_: WebSocketStateChangeType::StateChange,
|
||||
changed: VecMap::new(),
|
||||
push_state,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_json(&self) -> String {
|
||||
serde_json::to_string(self).unwrap()
|
||||
}
|
||||
}
|
|
@ -51,6 +51,7 @@ pub struct Response {
|
|||
pub session_state: u32,
|
||||
|
||||
#[serde(rename = "createdIds")]
|
||||
#[serde(skip_serializing_if = "HashMap::is_empty")]
|
||||
pub created_ids: HashMap<String, Id>,
|
||||
}
|
||||
|
||||
|
|
|
@ -34,6 +34,8 @@ p256 = { version = "0.13", features = ["ecdh"] }
|
|||
hkdf = "0.12.3"
|
||||
sha2 = "0.10.1"
|
||||
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"]}
|
||||
tokio-tungstenite = "0.19.0"
|
||||
tungstenite = "0.19.0"
|
||||
|
||||
[dev-dependencies]
|
||||
ece = "2.2"
|
||||
|
|
|
@ -101,6 +101,9 @@ impl crate::Config {
|
|||
oauth_max_auth_attempts: settings.property_or_static("oauth.max-auth-attempts", "3")?,
|
||||
event_source_throttle: settings
|
||||
.property_or_static("jmap.event-source.throttle", "1s")?,
|
||||
web_socket_throttle: settings.property_or_static("jmap.web-socket.throttle", "1s")?,
|
||||
web_socket_timeout: settings.property_or_static("jmap.web-socket.timeout", "10m")?,
|
||||
web_socket_heartbeat: settings.property_or_static("jmap.web-socket.heartbeat", "1m")?,
|
||||
push_max_total: settings.property_or_static("jmap.push.max-total", "100")?,
|
||||
};
|
||||
config.add_capabilites(settings);
|
||||
|
|
|
@ -24,7 +24,7 @@ struct Ping {
|
|||
impl JMAP {
|
||||
pub async fn handle_event_source(
|
||||
&self,
|
||||
req: &HttpRequest,
|
||||
req: HttpRequest,
|
||||
acl_token: Arc<AclToken>,
|
||||
) -> HttpResponse {
|
||||
// Parse query
|
||||
|
|
|
@ -10,6 +10,7 @@ use hyper::{
|
|||
};
|
||||
use jmap_proto::{
|
||||
error::request::{RequestError, RequestLimitError},
|
||||
request::Request,
|
||||
response::Response,
|
||||
types::{blob::BlobId, id::Id},
|
||||
};
|
||||
|
@ -23,39 +24,96 @@ use crate::{
|
|||
auth::oauth::OAuthMetadata,
|
||||
blob::{DownloadResponse, UploadResponse},
|
||||
services::state,
|
||||
websocket::upgrade::upgrade_websocket_connection,
|
||||
JMAP,
|
||||
};
|
||||
|
||||
use super::{session::Session, HtmlResponse, HttpResponse, JmapSessionManager, JsonResponse};
|
||||
use super::{
|
||||
session::Session, HtmlResponse, HttpRequest, HttpResponse, JmapSessionManager, JsonResponse,
|
||||
};
|
||||
|
||||
impl JMAP {
|
||||
pub async fn parse_request(
|
||||
&self,
|
||||
req: &mut hyper::Request<hyper::body::Incoming>,
|
||||
remote_ip: IpAddr,
|
||||
instance: &Arc<ServerInstance>,
|
||||
) -> HttpResponse {
|
||||
let mut path = req.uri().path().split('/');
|
||||
path.next();
|
||||
pub async fn parse_jmap_request(
|
||||
jmap: Arc<JMAP>,
|
||||
mut req: HttpRequest,
|
||||
remote_ip: IpAddr,
|
||||
instance: Arc<ServerInstance>,
|
||||
) -> HttpResponse {
|
||||
let mut path = req.uri().path().split('/');
|
||||
path.next();
|
||||
|
||||
match path.next().unwrap_or("") {
|
||||
"jmap" => {
|
||||
// Authenticate request
|
||||
let (_in_flight, acl_token) = match self.authenticate_headers(req, remote_ip).await
|
||||
{
|
||||
Ok(Some(session)) => session,
|
||||
Ok(None) => return RequestError::unauthorized().into_http_response(),
|
||||
Err(err) => return err.into_http_response(),
|
||||
};
|
||||
match path.next().unwrap_or("") {
|
||||
"jmap" => {
|
||||
// Authenticate request
|
||||
let (_in_flight, acl_token) = match jmap.authenticate_headers(&req, remote_ip).await {
|
||||
Ok(Some(session)) => session,
|
||||
Ok(None) => return RequestError::unauthorized().into_http_response(),
|
||||
Err(err) => return err.into_http_response(),
|
||||
};
|
||||
|
||||
match (path.next().unwrap_or(""), req.method()) {
|
||||
("", &Method::POST) => {
|
||||
return match fetch_body(req, self.config.request_max_size).await {
|
||||
match (path.next().unwrap_or(""), req.method()) {
|
||||
("", &Method::POST) => {
|
||||
return match fetch_body(&mut req, jmap.config.request_max_size)
|
||||
.await
|
||||
.and_then(|bytes| {
|
||||
Request::parse(
|
||||
&bytes,
|
||||
jmap.config.request_max_calls,
|
||||
jmap.config.request_max_size,
|
||||
)
|
||||
}) {
|
||||
Ok(request) => {
|
||||
//let _ = println!("<- {}", String::from_utf8_lossy(&bytes));
|
||||
|
||||
match jmap.handle_request(request, acl_token, &instance).await {
|
||||
Ok(response) => response.into_http_response(),
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
}
|
||||
Err(err) => err.into_http_response(),
|
||||
};
|
||||
}
|
||||
("download", &Method::GET) => {
|
||||
if let (Some(_), Some(blob_id), Some(name)) = (
|
||||
path.next().and_then(|p| Id::from_bytes(p.as_bytes())),
|
||||
path.next().and_then(BlobId::from_base32),
|
||||
path.next(),
|
||||
) {
|
||||
return match jmap.blob_download(&blob_id, &acl_token).await {
|
||||
Ok(Some(blob)) => DownloadResponse {
|
||||
filename: name.to_string(),
|
||||
content_type: req
|
||||
.uri()
|
||||
.query()
|
||||
.and_then(|q| {
|
||||
form_urlencoded::parse(q.as_bytes())
|
||||
.find(|(k, _)| k == "accept")
|
||||
.map(|(_, v)| v.into_owned())
|
||||
})
|
||||
.unwrap_or("application/octet-stream".to_string()),
|
||||
blob,
|
||||
}
|
||||
.into_http_response(),
|
||||
Ok(None) => RequestError::not_found().into_http_response(),
|
||||
Err(_) => RequestError::internal_server_error().into_http_response(),
|
||||
};
|
||||
}
|
||||
}
|
||||
("upload", &Method::POST) => {
|
||||
if let Some(account_id) = path.next().and_then(|p| Id::from_bytes(p.as_bytes()))
|
||||
{
|
||||
return match fetch_body(&mut req, jmap.config.upload_max_size).await {
|
||||
Ok(bytes) => {
|
||||
//let delete = "fd";
|
||||
//println!("<- {}", String::from_utf8_lossy(&bytes));
|
||||
|
||||
match self.handle_request(&bytes, acl_token, instance).await {
|
||||
match jmap
|
||||
.blob_upload(
|
||||
account_id,
|
||||
req.headers()
|
||||
.get(CONTENT_TYPE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("application/octet-stream"),
|
||||
&bytes,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => response.into_http_response(),
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
|
@ -63,141 +121,90 @@ impl JMAP {
|
|||
Err(err) => err.into_http_response(),
|
||||
};
|
||||
}
|
||||
("download", &Method::GET) => {
|
||||
if let (Some(_), Some(blob_id), Some(name)) = (
|
||||
path.next().and_then(|p| Id::from_bytes(p.as_bytes())),
|
||||
path.next().and_then(BlobId::from_base32),
|
||||
path.next(),
|
||||
) {
|
||||
return match self.blob_download(&blob_id, &acl_token).await {
|
||||
Ok(Some(blob)) => DownloadResponse {
|
||||
filename: name.to_string(),
|
||||
content_type: req
|
||||
.uri()
|
||||
.query()
|
||||
.and_then(|q| {
|
||||
form_urlencoded::parse(q.as_bytes())
|
||||
.find(|(k, _)| k == "accept")
|
||||
.map(|(_, v)| v.into_owned())
|
||||
})
|
||||
.unwrap_or("application/octet-stream".to_string()),
|
||||
blob,
|
||||
}
|
||||
.into_http_response(),
|
||||
Ok(None) => RequestError::not_found().into_http_response(),
|
||||
Err(_) => {
|
||||
RequestError::internal_server_error().into_http_response()
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
("upload", &Method::POST) => {
|
||||
if let Some(account_id) =
|
||||
path.next().and_then(|p| Id::from_bytes(p.as_bytes()))
|
||||
{
|
||||
return match fetch_body(req, self.config.upload_max_size).await {
|
||||
Ok(bytes) => {
|
||||
match self
|
||||
.blob_upload(
|
||||
account_id,
|
||||
req.headers()
|
||||
.get(CONTENT_TYPE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("application/octet-stream"),
|
||||
&bytes,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => response.into_http_response(),
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
}
|
||||
Err(err) => err.into_http_response(),
|
||||
};
|
||||
}
|
||||
}
|
||||
("eventsource", &Method::GET) => {
|
||||
return self.handle_event_source(req, acl_token).await
|
||||
}
|
||||
("ws", &Method::GET) => {
|
||||
todo!()
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
".well-known" => match (path.next().unwrap_or(""), req.method()) {
|
||||
("jmap", &Method::GET) => {
|
||||
// Authenticate request
|
||||
let (_in_flight, acl_token) =
|
||||
match self.authenticate_headers(req, remote_ip).await {
|
||||
Ok(Some(session)) => session,
|
||||
Ok(None) => return RequestError::unauthorized().into_http_response(),
|
||||
Err(err) => return err.into_http_response(),
|
||||
};
|
||||
|
||||
return match self.handle_session_resource(instance, acl_token).await {
|
||||
Ok(session) => session.into_http_response(),
|
||||
Err(err) => err.into_http_response(),
|
||||
};
|
||||
("eventsource", &Method::GET) => {
|
||||
return jmap.handle_event_source(req, acl_token).await
|
||||
}
|
||||
("oauth-authorization-server", &Method::GET) => {
|
||||
let remote_addr = self.build_remote_addr(req, remote_ip);
|
||||
// Limit anonymous requests
|
||||
return match self.is_anonymous_allowed(remote_addr) {
|
||||
Ok(_) => JsonResponse::new(OAuthMetadata::new(&instance.data))
|
||||
.into_http_response(),
|
||||
Err(err) => err.into_http_response(),
|
||||
};
|
||||
("ws", &Method::GET) => {
|
||||
return upgrade_websocket_connection(jmap, req, acl_token, instance.clone())
|
||||
.await;
|
||||
}
|
||||
_ => (),
|
||||
},
|
||||
"auth" => {
|
||||
let remote_addr = self.build_remote_addr(req, remote_ip);
|
||||
}
|
||||
}
|
||||
".well-known" => match (path.next().unwrap_or(""), req.method()) {
|
||||
("jmap", &Method::GET) => {
|
||||
// Authenticate request
|
||||
let (_in_flight, acl_token) = match jmap.authenticate_headers(&req, remote_ip).await
|
||||
{
|
||||
Ok(Some(session)) => session,
|
||||
Ok(None) => return RequestError::unauthorized().into_http_response(),
|
||||
Err(err) => return err.into_http_response(),
|
||||
};
|
||||
|
||||
match (path.next().unwrap_or(""), req.method()) {
|
||||
("", &Method::GET) => {
|
||||
return match self.is_anonymous_allowed(remote_addr) {
|
||||
Ok(_) => self.handle_user_device_auth(req).await,
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
return match jmap.handle_session_resource(instance, acl_token).await {
|
||||
Ok(session) => session.into_http_response(),
|
||||
Err(err) => err.into_http_response(),
|
||||
};
|
||||
}
|
||||
("oauth-authorization-server", &Method::GET) => {
|
||||
let remote_addr = jmap.build_remote_addr(&req, remote_ip);
|
||||
// Limit anonymous requests
|
||||
return match jmap.is_anonymous_allowed(remote_addr) {
|
||||
Ok(_) => {
|
||||
JsonResponse::new(OAuthMetadata::new(&instance.data)).into_http_response()
|
||||
}
|
||||
("", &Method::POST) => {
|
||||
return match self.is_auth_allowed(remote_addr) {
|
||||
Ok(_) => self.handle_user_device_auth_post(req).await,
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
}
|
||||
("code", &Method::GET) => {
|
||||
return match self.is_anonymous_allowed(remote_addr) {
|
||||
Ok(_) => self.handle_user_code_auth(req).await,
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
}
|
||||
("code", &Method::POST) => {
|
||||
return match self.is_auth_allowed(remote_addr) {
|
||||
Ok(_) => self.handle_user_code_auth_post(req).await,
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
}
|
||||
("device", &Method::POST) => {
|
||||
return match self.is_anonymous_allowed(remote_addr) {
|
||||
Ok(_) => self.handle_device_auth(req, instance).await,
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
}
|
||||
("token", &Method::POST) => {
|
||||
return match self.is_anonymous_allowed(remote_addr) {
|
||||
Ok(_) => self.handle_token_request(req).await,
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
Err(err) => err.into_http_response(),
|
||||
};
|
||||
}
|
||||
_ => (),
|
||||
},
|
||||
"auth" => {
|
||||
let remote_addr = jmap.build_remote_addr(&req, remote_ip);
|
||||
|
||||
match (path.next().unwrap_or(""), req.method()) {
|
||||
("", &Method::GET) => {
|
||||
return match jmap.is_anonymous_allowed(remote_addr) {
|
||||
Ok(_) => jmap.handle_user_device_auth(&mut req).await,
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
}
|
||||
("", &Method::POST) => {
|
||||
return match jmap.is_auth_allowed(remote_addr) {
|
||||
Ok(_) => jmap.handle_user_device_auth_post(&mut req).await,
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
}
|
||||
("code", &Method::GET) => {
|
||||
return match jmap.is_anonymous_allowed(remote_addr) {
|
||||
Ok(_) => jmap.handle_user_code_auth(&mut req).await,
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
}
|
||||
("code", &Method::POST) => {
|
||||
return match jmap.is_auth_allowed(remote_addr) {
|
||||
Ok(_) => jmap.handle_user_code_auth_post(&mut req).await,
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
}
|
||||
("device", &Method::POST) => {
|
||||
return match jmap.is_anonymous_allowed(remote_addr) {
|
||||
Ok(_) => jmap.handle_device_auth(&mut req, instance).await,
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
}
|
||||
("token", &Method::POST) => {
|
||||
return match jmap.is_anonymous_allowed(remote_addr) {
|
||||
Ok(_) => jmap.handle_token_request(&mut req).await,
|
||||
Err(err) => err.into_http_response(),
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
RequestError::not_found().into_http_response()
|
||||
_ => (),
|
||||
}
|
||||
RequestError::not_found().into_http_response()
|
||||
}
|
||||
|
||||
impl SessionManager for JmapSessionManager {
|
||||
|
@ -246,7 +253,7 @@ impl SessionManager for JmapSessionManager {
|
|||
}
|
||||
}
|
||||
|
||||
async fn handle_request<T: AsyncRead + AsyncWrite + Unpin + 'static>(
|
||||
async fn handle_request<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
jmap: Arc<JMAP>,
|
||||
session: SessionData<T>,
|
||||
) {
|
||||
|
@ -257,27 +264,25 @@ async fn handle_request<T: AsyncRead + AsyncWrite + Unpin + 'static>(
|
|||
.keep_alive(true)
|
||||
.serve_connection(
|
||||
session.stream,
|
||||
service_fn(|mut req: hyper::Request<body::Incoming>| {
|
||||
service_fn(|req: hyper::Request<body::Incoming>| {
|
||||
let jmap = jmap.clone();
|
||||
let span = span.clone();
|
||||
let instance = session.instance.clone();
|
||||
|
||||
async move {
|
||||
let response = jmap
|
||||
.parse_request(&mut req, session.remote_ip, &instance)
|
||||
.await;
|
||||
|
||||
tracing::debug!(
|
||||
parent: &span,
|
||||
event = "request",
|
||||
uri = req.uri().to_string(),
|
||||
status = response.status().to_string(),
|
||||
);
|
||||
|
||||
let response = parse_jmap_request(jmap, req, session.remote_ip, instance).await;
|
||||
|
||||
Ok::<_, hyper::Error>(response)
|
||||
}
|
||||
}),
|
||||
)
|
||||
.with_upgrades()
|
||||
.await
|
||||
{
|
||||
tracing::debug!(
|
||||
|
@ -289,10 +294,7 @@ async fn handle_request<T: AsyncRead + AsyncWrite + Unpin + 'static>(
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn fetch_body(
|
||||
req: &mut hyper::Request<hyper::body::Incoming>,
|
||||
max_size: usize,
|
||||
) -> Result<Vec<u8>, RequestError> {
|
||||
pub async fn fetch_body(req: &mut HttpRequest, max_size: usize) -> Result<Vec<u8>, RequestError> {
|
||||
let mut bytes = Vec::with_capacity(1024);
|
||||
while let Some(Ok(frame)) = req.frame().await {
|
||||
if let Some(data) = frame.data_ref() {
|
||||
|
|
|
@ -17,15 +17,10 @@ use crate::{auth::AclToken, JMAP};
|
|||
impl JMAP {
|
||||
pub async fn handle_request(
|
||||
&self,
|
||||
bytes: &[u8],
|
||||
request: Request,
|
||||
acl_token: Arc<AclToken>,
|
||||
instance: &Arc<ServerInstance>,
|
||||
) -> Result<Response, RequestError> {
|
||||
let request = Request::parse(
|
||||
bytes,
|
||||
self.config.request_max_calls,
|
||||
self.config.request_max_size,
|
||||
)?;
|
||||
let mut response = Response::new(
|
||||
acl_token.state(),
|
||||
request.created_ids.unwrap_or_default(),
|
||||
|
|
|
@ -143,7 +143,7 @@ pub struct BaseCapabilities {
|
|||
impl JMAP {
|
||||
pub async fn handle_session_resource(
|
||||
&self,
|
||||
instance: &ServerInstance,
|
||||
instance: Arc<ServerInstance>,
|
||||
acl_token: Arc<AclToken>,
|
||||
) -> Result<Session, RequestError> {
|
||||
let mut session = Session::new(&instance.data, &self.config.capabilities);
|
||||
|
|
|
@ -31,7 +31,7 @@ impl JMAP {
|
|||
pub async fn handle_device_auth(
|
||||
&self,
|
||||
req: &mut HttpRequest,
|
||||
instance: &ServerInstance,
|
||||
instance: Arc<ServerInstance>,
|
||||
) -> HttpResponse {
|
||||
// Parse form
|
||||
let client_id = match parse_form_data(req)
|
||||
|
|
|
@ -46,6 +46,7 @@ pub mod sieve;
|
|||
pub mod submission;
|
||||
pub mod thread;
|
||||
pub mod vacation;
|
||||
pub mod websocket;
|
||||
|
||||
pub const SUPERUSER_ID: u32 = 0;
|
||||
pub const LONG_SLUMBER: Duration = Duration::from_secs(60 * 60 * 24);
|
||||
|
@ -103,6 +104,10 @@ pub struct Config {
|
|||
pub event_source_throttle: Duration,
|
||||
pub push_max_total: usize,
|
||||
|
||||
pub web_socket_throttle: Duration,
|
||||
pub web_socket_timeout: Duration,
|
||||
pub web_socket_heartbeat: Duration,
|
||||
|
||||
pub oauth_key: String,
|
||||
pub oauth_expiry_user_code: u64,
|
||||
pub oauth_expiry_auth_code: u64,
|
||||
|
|
2
crates/jmap/src/websocket/mod.rs
Normal file
2
crates/jmap/src/websocket/mod.rs
Normal file
|
@ -0,0 +1,2 @@
|
|||
pub mod stream;
|
||||
pub mod upgrade;
|
192
crates/jmap/src/websocket/stream.rs
Normal file
192
crates/jmap/src/websocket/stream.rs
Normal file
|
@ -0,0 +1,192 @@
|
|||
use std::{sync::Arc, time::Instant};
|
||||
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use hyper::upgrade::Upgraded;
|
||||
use jmap_proto::{
|
||||
error::request::RequestError,
|
||||
request::websocket::{
|
||||
WebSocketMessage, WebSocketRequestError, WebSocketResponse, WebSocketStateChange,
|
||||
},
|
||||
types::type_state::TypeState,
|
||||
};
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tungstenite::Message;
|
||||
use utils::{listener::ServerInstance, map::bitmap::Bitmap};
|
||||
|
||||
use crate::{auth::AclToken, JMAP};
|
||||
|
||||
impl JMAP {
|
||||
pub async fn handle_websocket_stream(
|
||||
&self,
|
||||
mut stream: WebSocketStream<Upgraded>,
|
||||
acl_token: Arc<AclToken>,
|
||||
instance: Arc<ServerInstance>,
|
||||
) {
|
||||
let span = tracing::info_span!(
|
||||
"WebSocket connection established",
|
||||
"account_id" = acl_token.primary_id(),
|
||||
"url" = instance.data,
|
||||
);
|
||||
|
||||
// Set timeouts
|
||||
let throttle = self.config.web_socket_throttle;
|
||||
let timeout = self.config.web_socket_timeout;
|
||||
let heartbeat = self.config.web_socket_heartbeat;
|
||||
let mut last_request = Instant::now();
|
||||
let mut last_changes_sent = Instant::now() - throttle;
|
||||
let mut last_heartbeat = Instant::now() - heartbeat;
|
||||
let mut next_event = heartbeat;
|
||||
|
||||
// Register with state manager
|
||||
let mut change_rx = if let Some(change_rx) = self
|
||||
.subscribe_state_manager(
|
||||
acl_token.primary_id(),
|
||||
acl_token.primary_id(),
|
||||
Bitmap::all(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
change_rx
|
||||
} else {
|
||||
let _ = stream
|
||||
.send(Message::Text(
|
||||
WebSocketRequestError::from(RequestError::internal_server_error()).to_json(),
|
||||
))
|
||||
.await;
|
||||
return;
|
||||
};
|
||||
let mut changes = WebSocketStateChange::new(None);
|
||||
let mut change_types: Bitmap<TypeState> = Bitmap::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
event = tokio::time::timeout(next_event, stream.next()) => {
|
||||
match event {
|
||||
Ok(Some(Ok(event))) => {
|
||||
match event {
|
||||
Message::Text(text) => {
|
||||
let response = match WebSocketMessage::parse(
|
||||
text.as_bytes(),
|
||||
self.config.request_max_calls,
|
||||
self.config.request_max_size,
|
||||
) {
|
||||
Ok(WebSocketMessage::Request(request)) => {
|
||||
match self
|
||||
.handle_request(
|
||||
request.request,
|
||||
acl_token.clone(),
|
||||
&instance,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
WebSocketResponse::from_response(response, request.id)
|
||||
.to_json()
|
||||
}
|
||||
Err(err) => {
|
||||
WebSocketRequestError::from_error(err, request.id)
|
||||
.to_json()
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(WebSocketMessage::PushEnable(push_enable)) => {
|
||||
change_types = if !push_enable.data_types.is_empty() {
|
||||
push_enable.data_types.into()
|
||||
} else {
|
||||
Bitmap::all()
|
||||
};
|
||||
continue;
|
||||
}
|
||||
Ok(WebSocketMessage::PushDisable) => {
|
||||
change_types = Bitmap::new();
|
||||
continue;
|
||||
}
|
||||
Err(err) => err.to_json(),
|
||||
};
|
||||
if let Err(err) = stream.send(Message::Text(response)).await {
|
||||
tracing::debug!(parent: &span, error = ?err, "Failed to send text message");
|
||||
}
|
||||
}
|
||||
Message::Ping(bytes) => {
|
||||
if let Err(err) = stream.send(Message::Pong(bytes)).await {
|
||||
tracing::debug!(parent: &span, error = ?err, "Failed to send pong message");
|
||||
}
|
||||
}
|
||||
Message::Close(frame) => {
|
||||
let _ = stream.close(frame).await;
|
||||
break;
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
||||
last_request = Instant::now();
|
||||
last_heartbeat = Instant::now();
|
||||
}
|
||||
Ok(Some(Err(err))) => {
|
||||
tracing::debug!(parent: &span, error = ?err, "Websocket error");
|
||||
break;
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(_) => {
|
||||
// Verify timeout
|
||||
if last_request.elapsed() > timeout {
|
||||
tracing::debug!(
|
||||
parent: &span,
|
||||
event = "disconnect",
|
||||
"Disconnecting idle client"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
state_change = change_rx.recv() => {
|
||||
if let Some(state_change) = state_change {
|
||||
if !change_types.is_empty() && state_change
|
||||
.types
|
||||
.iter()
|
||||
.any(|(t, _)| change_types.contains(*t))
|
||||
{
|
||||
for (type_state, change_id) in state_change.types {
|
||||
changes
|
||||
.changed
|
||||
.get_mut_or_insert(state_change.account_id.into())
|
||||
.set(type_state, change_id.into());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::debug!(
|
||||
parent: &span,
|
||||
event = "channel-closed",
|
||||
"Disconnecting client, channel closed"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !changes.changed.is_empty() {
|
||||
// Send any queued changes
|
||||
let elapsed = last_changes_sent.elapsed();
|
||||
if elapsed >= throttle {
|
||||
if let Err(err) = stream.send(Message::Text(changes.to_json())).await {
|
||||
tracing::debug!(parent: &span, error = ?err, "Failed to send state change message");
|
||||
}
|
||||
changes.changed.clear();
|
||||
last_changes_sent = Instant::now();
|
||||
last_heartbeat = Instant::now();
|
||||
next_event = heartbeat;
|
||||
} else {
|
||||
next_event = throttle - elapsed;
|
||||
}
|
||||
} else if last_heartbeat.elapsed() > heartbeat {
|
||||
if let Err(err) = stream.send(Message::Ping(vec![])).await {
|
||||
tracing::debug!(parent: &span, error = ?err, "Failed to send ping message");
|
||||
break;
|
||||
}
|
||||
last_heartbeat = Instant::now();
|
||||
next_event = heartbeat;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
88
crates/jmap/src/websocket/upgrade.rs
Normal file
88
crates/jmap/src/websocket/upgrade.rs
Normal file
|
@ -0,0 +1,88 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use hyper::{body::Bytes, Response, StatusCode};
|
||||
use jmap_proto::error::request::RequestError;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tungstenite::{handshake::derive_accept_key, protocol::Role};
|
||||
use utils::listener::ServerInstance;
|
||||
|
||||
use crate::{
|
||||
api::{http::ToHttpResponse, HttpRequest, HttpResponse},
|
||||
auth::AclToken,
|
||||
JMAP,
|
||||
};
|
||||
|
||||
pub async fn upgrade_websocket_connection(
|
||||
jmap: Arc<JMAP>,
|
||||
req: HttpRequest,
|
||||
acl_token: Arc<AclToken>,
|
||||
instance: Arc<ServerInstance>,
|
||||
) -> HttpResponse {
|
||||
let headers = req.headers();
|
||||
if headers
|
||||
.get(hyper::header::CONNECTION)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
!= Some("Upgrade")
|
||||
|| headers
|
||||
.get(hyper::header::UPGRADE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
!= Some("websocket")
|
||||
{
|
||||
return RequestError::blank(
|
||||
StatusCode::BAD_REQUEST.as_u16(),
|
||||
"WebSocket upgrade failed",
|
||||
"Missing or Invalid Connection or Upgrade headers.",
|
||||
)
|
||||
.into_http_response();
|
||||
}
|
||||
let derived_key = match (
|
||||
headers
|
||||
.get("Sec-WebSocket-Key")
|
||||
.and_then(|h| h.to_str().ok()),
|
||||
headers
|
||||
.get("Sec-WebSocket-Version")
|
||||
.and_then(|h| h.to_str().ok()),
|
||||
) {
|
||||
(Some(key), Some(version)) if version == "13" => derive_accept_key(key.as_bytes()),
|
||||
_ => {
|
||||
return RequestError::blank(
|
||||
StatusCode::BAD_REQUEST.as_u16(),
|
||||
"WebSocket upgrade failed",
|
||||
"Missing or Invalid Sec-WebSocket-Key headers.",
|
||||
)
|
||||
.into_http_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Spawn WebSocket connection
|
||||
tokio::spawn(async move {
|
||||
// Upgrade connection
|
||||
match hyper::upgrade::on(req).await {
|
||||
Ok(upgraded) => {
|
||||
jmap.handle_websocket_stream(
|
||||
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
|
||||
acl_token,
|
||||
instance,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!("WebSocket upgrade failed: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Response::builder()
|
||||
.status(hyper::StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(hyper::header::CONNECTION, "upgrade")
|
||||
.header(hyper::header::UPGRADE, "websocket")
|
||||
.header("Sec-WebSocket-Accept", &derived_key)
|
||||
.header("Sec-WebSocket-Protocol", "jmap")
|
||||
.body(
|
||||
Full::new(Bytes::from("Switching to WebSocket protocol"))
|
||||
.map_err(|never| match never {})
|
||||
.boxed(),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
|
@ -29,6 +29,7 @@ pub mod sieve_script;
|
|||
pub mod thread_get;
|
||||
pub mod thread_merge;
|
||||
pub mod vacation_response;
|
||||
pub mod websocket;
|
||||
|
||||
const SERVER: &str = r#"
|
||||
[server]
|
||||
|
@ -193,9 +194,8 @@ pub async fn jmap_tests() {
|
|||
//push_subscription::test(params.server.clone(), &mut params.client).await;
|
||||
//sieve_script::test(params.server.clone(), &mut params.client).await;
|
||||
//vacation_response::test(params.server.clone(), &mut params.client).await;
|
||||
email_submission::test(params.server.clone(), &mut params.client).await;
|
||||
|
||||
let websockets = "todo";
|
||||
//email_submission::test(params.server.clone(), &mut params.client).await;
|
||||
websocket::test(params.server.clone(), &mut params.client).await;
|
||||
|
||||
if delete {
|
||||
params.temp_dir.delete();
|
||||
|
|
148
tests/src/jmap/websocket.rs
Normal file
148
tests/src/jmap/websocket.rs
Normal file
|
@ -0,0 +1,148 @@
|
|||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use ahash::AHashSet;
|
||||
use futures::StreamExt;
|
||||
use jmap::JMAP;
|
||||
use jmap_client::{
|
||||
client::Client,
|
||||
client_ws::WebSocketMessage,
|
||||
core::{
|
||||
response::{Response, TaggedMethodResponse},
|
||||
set::SetObject,
|
||||
},
|
||||
TypeState,
|
||||
};
|
||||
use jmap_proto::types::id::Id;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::jmap::{mailbox::destroy_all_mailboxes, test_account_create, test_account_login};
|
||||
|
||||
pub async fn test(server: Arc<JMAP>, admin_client: &mut Client) {
|
||||
println!("Running WebSockets tests...");
|
||||
|
||||
// Authenticate all accounts
|
||||
let account_id = test_account_create(&server, "jdoe@example.com", "12345", "John Doe")
|
||||
.await
|
||||
.to_string();
|
||||
let client = test_account_login("jdoe@example.com", "12345").await;
|
||||
|
||||
let mut ws_stream = client.connect_ws().await.unwrap();
|
||||
|
||||
let (stream_tx, mut stream_rx) = mpsc::channel::<WebSocketMessage>(100);
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(change) = ws_stream.next().await {
|
||||
stream_tx.send(change.unwrap()).await.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
// Create mailbox
|
||||
let mut request = client.build();
|
||||
let create_id = request
|
||||
.set_mailbox()
|
||||
.create()
|
||||
.name("WebSocket Test")
|
||||
.create_id()
|
||||
.unwrap();
|
||||
let request_id = request.send_ws().await.unwrap();
|
||||
let mut response = expect_response(&mut stream_rx).await;
|
||||
assert_eq!(request_id, response.request_id().unwrap());
|
||||
let mailbox_id = response
|
||||
.pop_method_response()
|
||||
.unwrap()
|
||||
.unwrap_set_mailbox()
|
||||
.unwrap()
|
||||
.created(&create_id)
|
||||
.unwrap()
|
||||
.take_id();
|
||||
|
||||
// Enable push notifications
|
||||
client
|
||||
.enable_push_ws(None::<Vec<_>>, None::<&str>)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Make changes over standard HTTP and expect a push notification via WebSockets
|
||||
client
|
||||
.mailbox_update_sort_order(&mailbox_id, 1)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_state(&mut stream_rx, &[TypeState::Mailbox]).await;
|
||||
|
||||
// Multiple changes should be grouped and delivered in intervals
|
||||
for num in 0..5 {
|
||||
client
|
||||
.mailbox_update_sort_order(&mailbox_id, num)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
assert_state(&mut stream_rx, &[TypeState::Mailbox]).await;
|
||||
expect_nothing(&mut stream_rx).await;
|
||||
|
||||
// Disable push notifications
|
||||
client.disable_push_ws().await.unwrap();
|
||||
|
||||
// No more changes should be received
|
||||
let mut request = client.build();
|
||||
request.set_mailbox().destroy([&mailbox_id]);
|
||||
request.send_ws().await.unwrap();
|
||||
expect_response(&mut stream_rx)
|
||||
.await
|
||||
.pop_method_response()
|
||||
.unwrap()
|
||||
.unwrap_set_mailbox()
|
||||
.unwrap()
|
||||
.destroyed(&mailbox_id)
|
||||
.unwrap();
|
||||
expect_nothing(&mut stream_rx).await;
|
||||
|
||||
admin_client.set_default_account_id(account_id);
|
||||
destroy_all_mailboxes(admin_client).await;
|
||||
|
||||
server.store.assert_is_empty().await;
|
||||
}
|
||||
|
||||
async fn expect_response(
|
||||
stream_rx: &mut mpsc::Receiver<WebSocketMessage>,
|
||||
) -> Response<TaggedMethodResponse> {
|
||||
match tokio::time::timeout(Duration::from_millis(100), stream_rx.recv()).await {
|
||||
Ok(Some(message)) => match message {
|
||||
WebSocketMessage::Response(response) => response,
|
||||
_ => panic!("Expected response, got: {:?}", message),
|
||||
},
|
||||
result => {
|
||||
panic!("Timeout waiting for websocket: {:?}", result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn assert_state(stream_rx: &mut mpsc::Receiver<WebSocketMessage>, state: &[TypeState]) {
|
||||
match tokio::time::timeout(Duration::from_millis(700), stream_rx.recv()).await {
|
||||
Ok(Some(message)) => match message {
|
||||
WebSocketMessage::StateChange(changes) => {
|
||||
assert_eq!(
|
||||
changes
|
||||
.changes(&Id::new(1).to_string())
|
||||
.unwrap()
|
||||
.map(|x| x.0)
|
||||
.collect::<AHashSet<&TypeState>>(),
|
||||
state.iter().collect::<AHashSet<&TypeState>>()
|
||||
);
|
||||
}
|
||||
_ => panic!("Expected state change, got: {:?}", message),
|
||||
},
|
||||
result => {
|
||||
panic!("Timeout waiting for websocket: {:?}", result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn expect_nothing(stream_rx: &mut mpsc::Receiver<WebSocketMessage>) {
|
||||
match tokio::time::timeout(Duration::from_millis(1000), stream_rx.recv()).await {
|
||||
Err(_) => {}
|
||||
message => {
|
||||
panic!("Received a message when expecting nothing: {:?}", message);
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue