WebSocket tests passing

This commit is contained in:
Mauro D 2023-05-24 17:39:09 +00:00
parent 4ff2158783
commit 86a8a5f7d5
18 changed files with 1004 additions and 331 deletions

22
Cargo.lock generated
View file

@ -1686,7 +1686,9 @@ dependencies = [
"sqlx",
"store",
"tokio",
"tokio-tungstenite",
"tracing",
"tungstenite",
"utils",
]
@ -1702,6 +1704,7 @@ dependencies = [
"maybe-async 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)",
"parking_lot",
"reqwest",
"rustls 0.21.1",
"serde",
"serde_json",
"tokio",
@ -3931,18 +3934,17 @@ dependencies = [
[[package]]
name = "tokio-tungstenite"
version = "0.18.0"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd"
checksum = "ec509ac96e9a0c43427c74f003127d953a265737636129424288d27cb5c4b12c"
dependencies = [
"futures-util",
"log",
"rustls 0.20.8",
"rustls 0.21.1",
"tokio",
"tokio-rustls 0.23.4",
"tokio-rustls 0.24.0",
"tungstenite",
"webpki",
"webpki-roots 0.22.6",
"webpki-roots 0.23.0",
]
[[package]]
@ -4195,18 +4197,18 @@ checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
[[package]]
name = "tungstenite"
version = "0.18.0"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788"
checksum = "15fba1a6d6bb030745759a9a2a588bfe8490fc8b4751a277db3a0be1c9ebbf67"
dependencies = [
"base64 0.13.1",
"byteorder",
"bytes",
"data-encoding",
"http",
"httparse",
"log",
"rand",
"rustls 0.20.8",
"rustls 0.21.1",
"sha1",
"thiserror",
"url",

View file

@ -3,6 +3,7 @@ pub mod echo;
pub mod method;
pub mod parser;
pub mod reference;
pub mod websocket;
use std::{
collections::HashMap,
@ -29,7 +30,7 @@ use crate::{
use self::{echo::Echo, method::MethodName};
#[derive(Debug)]
#[derive(Debug, Default)]
pub struct Request {
pub using: u32,
pub method_calls: Vec<Call<RequestMethod>>,

View file

@ -40,152 +40,7 @@ impl Request {
let mut parser = Parser::new(json);
parser.next_token::<String>()?.assert(Token::DictStart)?;
while let Some(key) = parser.next_dict_key::<u128>()? {
match key {
0x0067_6e69_7375 => {
found_valid_keys = true;
parser.next_token::<Ignore>()?.assert(Token::ArrayStart)?;
loop {
match parser.next_token::<Capability>()? {
Token::String(capability) => {
request.using |= capability as u32;
}
Token::Comma => (),
Token::ArrayEnd => break,
token => {
return Err(token
.error("capability", &token.to_string())
.into())
}
}
}
}
0x0073_6c6c_6143_646f_6874_656d => {
found_valid_keys = true;
parser
.next_token::<Ignore>()?
.assert_jmap(Token::ArrayStart)?;
loop {
match parser.next_token::<Ignore>()? {
Token::ArrayStart => (),
Token::Comma => continue,
Token::ArrayEnd => break,
_ => {
return Err(RequestError::not_request("Invalid JMAP request"));
}
};
if request.method_calls.len() < max_calls {
let method_name = match parser.next_token::<MethodName>() {
Ok(Token::String(method)) => method,
Ok(_) => {
return Err(RequestError::not_request(
"Invalid JMAP request",
));
}
Err(Error::Method(MethodError::InvalidArguments(_))) => {
MethodName::error()
}
Err(err) => {
return Err(err.into());
}
};
parser.next_token::<Ignore>()?.assert_jmap(Token::Comma)?;
parser.ctx = method_name.obj;
let start_depth_array = parser.depth_array;
let start_depth_dict = parser.depth_dict;
let method = match (&method_name.fnc, &method_name.obj) {
(MethodFunction::Get, _) => {
if method_name.obj != MethodObject::SearchSnippet {
GetRequest::parse(&mut parser).map(RequestMethod::Get)
} else {
GetSearchSnippetRequest::parse(&mut parser)
.map(RequestMethod::SearchSnippet)
}
}
(MethodFunction::Query, _) => {
QueryRequest::parse(&mut parser).map(RequestMethod::Query)
}
(MethodFunction::Set, _) => {
SetRequest::parse(&mut parser).map(RequestMethod::Set)
}
(MethodFunction::Changes, _) => {
ChangesRequest::parse(&mut parser)
.map(RequestMethod::Changes)
}
(MethodFunction::QueryChanges, _) => {
QueryChangesRequest::parse(&mut parser)
.map(RequestMethod::QueryChanges)
}
(MethodFunction::Copy, MethodObject::Email) => {
CopyRequest::parse(&mut parser).map(RequestMethod::Copy)
}
(MethodFunction::Copy, MethodObject::Blob) => {
CopyBlobRequest::parse(&mut parser)
.map(RequestMethod::CopyBlob)
}
(MethodFunction::Import, MethodObject::Email) => {
ImportEmailRequest::parse(&mut parser)
.map(RequestMethod::ImportEmail)
}
(MethodFunction::Parse, MethodObject::Email) => {
ParseEmailRequest::parse(&mut parser)
.map(RequestMethod::ParseEmail)
}
(MethodFunction::Validate, MethodObject::SieveScript) => {
ValidateSieveScriptRequest::parse(&mut parser)
.map(RequestMethod::ValidateScript)
}
(MethodFunction::Echo, MethodObject::Core) => {
Echo::parse(&mut parser).map(RequestMethod::Echo)
}
_ => Err(Error::Method(MethodError::UnknownMethod(
method_name.to_string(),
))),
};
let method = match method {
Ok(method) => method,
Err(Error::Method(err)) => {
parser.skip_token(start_depth_array, start_depth_dict)?;
RequestMethod::Error(err)
}
Err(err) => {
return Err(err.into());
}
};
parser.next_token::<Ignore>()?.assert_jmap(Token::Comma)?;
let id = parser.next_token::<String>()?.unwrap_string("")?;
parser
.next_token::<Ignore>()?
.assert_jmap(Token::ArrayEnd)?;
request.method_calls.push(Call {
id,
method,
name: method_name,
});
} else {
return Err(RequestError::limit(RequestLimitError::CallsIn));
}
}
}
0x7364_4964_6574_6165_7263 => {
found_valid_keys = true;
let mut created_ids = HashMap::new();
parser.next_token::<Ignore>()?.assert(Token::DictStart)?;
while let Some(key) = parser.next_dict_key::<String>()? {
created_ids.insert(
key,
parser.next_token::<Id>()?.unwrap_string("createdIds")?,
);
}
request.created_ids = Some(created_ids);
}
_ => {
parser.skip_token(parser.depth_array, parser.depth_dict)?;
}
}
found_valid_keys |= request.parse_key(&mut parser, max_calls, key)?;
}
if found_valid_keys {
@ -197,6 +52,147 @@ impl Request {
Err(RequestError::limit(RequestLimitError::Size))
}
}
pub(crate) fn parse_key(
&mut self,
parser: &mut Parser,
max_calls: usize,
key: u128,
) -> Result<bool, RequestError> {
match key {
0x0067_6e69_7375 => {
parser.next_token::<Ignore>()?.assert(Token::ArrayStart)?;
loop {
match parser.next_token::<Capability>()? {
Token::String(capability) => {
self.using |= capability as u32;
}
Token::Comma => (),
Token::ArrayEnd => break,
token => return Err(token.error("capability", &token.to_string()).into()),
}
}
Ok(true)
}
0x0073_6c6c_6143_646f_6874_656d => {
parser
.next_token::<Ignore>()?
.assert_jmap(Token::ArrayStart)?;
loop {
match parser.next_token::<Ignore>()? {
Token::ArrayStart => (),
Token::Comma => continue,
Token::ArrayEnd => break,
_ => {
return Err(RequestError::not_request("Invalid JMAP request"));
}
};
if self.method_calls.len() < max_calls {
let method_name = match parser.next_token::<MethodName>() {
Ok(Token::String(method)) => method,
Ok(_) => {
return Err(RequestError::not_request("Invalid JMAP request"));
}
Err(Error::Method(MethodError::InvalidArguments(_))) => {
MethodName::error()
}
Err(err) => {
return Err(err.into());
}
};
parser.next_token::<Ignore>()?.assert_jmap(Token::Comma)?;
parser.ctx = method_name.obj;
let start_depth_array = parser.depth_array;
let start_depth_dict = parser.depth_dict;
let method = match (&method_name.fnc, &method_name.obj) {
(MethodFunction::Get, _) => {
if method_name.obj != MethodObject::SearchSnippet {
GetRequest::parse(parser).map(RequestMethod::Get)
} else {
GetSearchSnippetRequest::parse(parser)
.map(RequestMethod::SearchSnippet)
}
}
(MethodFunction::Query, _) => {
QueryRequest::parse(parser).map(RequestMethod::Query)
}
(MethodFunction::Set, _) => {
SetRequest::parse(parser).map(RequestMethod::Set)
}
(MethodFunction::Changes, _) => {
ChangesRequest::parse(parser).map(RequestMethod::Changes)
}
(MethodFunction::QueryChanges, _) => {
QueryChangesRequest::parse(parser).map(RequestMethod::QueryChanges)
}
(MethodFunction::Copy, MethodObject::Email) => {
CopyRequest::parse(parser).map(RequestMethod::Copy)
}
(MethodFunction::Copy, MethodObject::Blob) => {
CopyBlobRequest::parse(parser).map(RequestMethod::CopyBlob)
}
(MethodFunction::Import, MethodObject::Email) => {
ImportEmailRequest::parse(parser).map(RequestMethod::ImportEmail)
}
(MethodFunction::Parse, MethodObject::Email) => {
ParseEmailRequest::parse(parser).map(RequestMethod::ParseEmail)
}
(MethodFunction::Validate, MethodObject::SieveScript) => {
ValidateSieveScriptRequest::parse(parser)
.map(RequestMethod::ValidateScript)
}
(MethodFunction::Echo, MethodObject::Core) => {
Echo::parse(parser).map(RequestMethod::Echo)
}
_ => Err(Error::Method(MethodError::UnknownMethod(
method_name.to_string(),
))),
};
let method = match method {
Ok(method) => method,
Err(Error::Method(err)) => {
parser.skip_token(start_depth_array, start_depth_dict)?;
RequestMethod::Error(err)
}
Err(err) => {
return Err(err.into());
}
};
parser.next_token::<Ignore>()?.assert_jmap(Token::Comma)?;
let id = parser.next_token::<String>()?.unwrap_string("")?;
parser
.next_token::<Ignore>()?
.assert_jmap(Token::ArrayEnd)?;
self.method_calls.push(Call {
id,
method,
name: method_name,
});
} else {
return Err(RequestError::limit(RequestLimitError::CallsIn));
}
}
Ok(true)
}
0x7364_4964_6574_6165_7263 => {
let mut created_ids = HashMap::new();
parser.next_token::<Ignore>()?.assert(Token::DictStart)?;
while let Some(key) = parser.next_dict_key::<String>()? {
created_ids
.insert(key, parser.next_token::<Id>()?.unwrap_string("createdIds")?);
}
self.created_ids = Some(created_ids);
Ok(true)
}
_ => {
parser.skip_token(parser.depth_array, parser.depth_dict)?;
Ok(false)
}
}
}
}
impl From<Error> for RequestError {

View file

@ -0,0 +1,236 @@
use std::{borrow::Cow, collections::HashMap};
use crate::{
error::request::{RequestError, RequestErrorType, RequestLimitError},
parser::{json::Parser, Error, JsonObjectParser, Token},
request::Call,
response::{serialize::serialize_hex, Response, ResponseMethod},
types::{id::Id, state::State, type_state::TypeState},
};
use utils::map::vec_map::VecMap;
use super::{Request, RequestProperty};
#[derive(Debug)]
pub struct WebSocketRequest {
pub id: Option<String>,
pub request: Request,
}
#[derive(Debug, serde::Serialize)]
pub struct WebSocketResponse {
#[serde(rename = "@type")]
_type: WebSocketResponseType,
#[serde(rename = "methodResponses")]
method_responses: Vec<Call<ResponseMethod>>,
#[serde(rename = "sessionState")]
#[serde(serialize_with = "serialize_hex")]
session_state: u32,
#[serde(rename(deserialize = "createdIds"))]
#[serde(skip_serializing_if = "HashMap::is_empty")]
created_ids: HashMap<String, Id>,
#[serde(rename = "requestId")]
#[serde(skip_serializing_if = "Option::is_none")]
request_id: Option<String>,
}
#[derive(Debug, PartialEq, Eq, serde::Serialize)]
pub enum WebSocketResponseType {
Response,
}
#[derive(Debug, Default, PartialEq, Eq)]
pub struct WebSocketPushEnable {
pub data_types: Vec<TypeState>,
pub push_state: Option<String>,
}
#[derive(Debug)]
pub enum WebSocketMessage {
Request(WebSocketRequest),
PushEnable(WebSocketPushEnable),
PushDisable,
}
#[derive(serde::Serialize, Debug)]
pub enum WebSocketStateChangeType {
StateChange,
}
#[derive(serde::Serialize, Debug)]
pub struct WebSocketStateChange {
#[serde(rename = "@type")]
pub type_: WebSocketStateChangeType,
pub changed: VecMap<Id, VecMap<TypeState, State>>,
#[serde(rename = "pushState")]
#[serde(skip_serializing_if = "Option::is_none")]
push_state: Option<String>,
}
#[derive(Debug, serde::Serialize)]
pub struct WebSocketRequestError {
#[serde(rename = "@type")]
pub type_: WebSocketRequestErrorType,
#[serde(rename = "type")]
p_type: RequestErrorType,
#[serde(skip_serializing_if = "Option::is_none")]
limit: Option<RequestLimitError>,
status: u16,
detail: Cow<'static, str>,
#[serde(rename = "requestId")]
#[serde(skip_serializing_if = "Option::is_none")]
pub request_id: Option<String>,
}
#[derive(serde::Serialize, Debug)]
pub enum WebSocketRequestErrorType {
RequestError,
}
enum MessageType {
Request,
PushEnable,
PushDisable,
None,
}
impl WebSocketMessage {
pub fn parse(
json: &[u8],
max_calls: usize,
max_size: usize,
) -> Result<Self, WebSocketRequestError> {
if json.len() <= max_size {
let mut message_type = MessageType::None;
let mut request = WebSocketRequest {
id: None,
request: Request::default(),
};
let mut push_enable = WebSocketPushEnable::default();
let mut found_request_keys = false;
let mut found_push_keys = false;
let mut parser = Parser::new(json);
parser.next_token::<String>()?.assert(Token::DictStart)?;
while let Some(key) = parser.next_dict_key::<u128>()? {
match key {
0x0065_7079_7440 => {
let rt = parser
.next_token::<RequestProperty>()?
.unwrap_string("@type")?;
message_type = match (rt.hash[0], rt.hash[1]) {
(0x0074_7365_7571_6552, 0) => MessageType::Request,
(0x616e_4568_7375_5074_656b_636f_5362_6557, 0x656c62) => {
MessageType::PushEnable
}
(0x7369_4468_7375_5074_656b_636f_5362_6557, 0x656c6261) => {
MessageType::PushDisable
}
_ => MessageType::None,
};
}
0x0073_6570_7954_6174_6164 => {
push_enable.data_types =
<Option<Vec<TypeState>>>::parse(&mut parser)?.unwrap_or_default();
found_push_keys = true;
}
0x0065_7461_7453_6873_7570 => {
push_enable.push_state = parser
.next_token::<String>()?
.unwrap_string_or_null("pushState")?;
found_push_keys = true;
}
0x6469 => {
request.id = parser.next_token::<String>()?.unwrap_string_or_null("id")?;
}
_ => {
found_request_keys |=
request.request.parse_key(&mut parser, max_calls, key)?;
}
}
}
match message_type {
MessageType::Request if found_request_keys => {
Ok(WebSocketMessage::Request(request))
}
MessageType::PushEnable if found_push_keys => {
Ok(WebSocketMessage::PushEnable(push_enable))
}
MessageType::PushDisable if !found_request_keys && !found_push_keys => {
Ok(WebSocketMessage::PushDisable)
}
_ => Err(RequestError::not_request("Invalid WebSocket JMAP request").into()),
}
} else {
Err(RequestError::limit(RequestLimitError::Size).into())
}
}
}
impl WebSocketRequestError {
pub fn from_error(error: RequestError, request_id: Option<String>) -> Self {
Self {
type_: WebSocketRequestErrorType::RequestError,
p_type: error.p_type,
limit: error.limit,
status: error.status,
detail: error.detail,
request_id,
}
}
pub fn to_json(&self) -> String {
serde_json::to_string(self).unwrap()
}
}
impl From<RequestError> for WebSocketRequestError {
fn from(value: RequestError) -> Self {
Self::from_error(value, None)
}
}
impl From<Error> for WebSocketRequestError {
fn from(value: Error) -> Self {
RequestError::from(value).into()
}
}
impl WebSocketResponse {
pub fn from_response(response: Response, request_id: Option<String>) -> Self {
Self {
_type: WebSocketResponseType::Response,
method_responses: response.method_responses,
session_state: response.session_state,
created_ids: response.created_ids,
request_id,
}
}
pub fn to_json(&self) -> String {
serde_json::to_string(self).unwrap()
}
}
impl WebSocketStateChange {
pub fn new(push_state: Option<String>) -> Self {
WebSocketStateChange {
type_: WebSocketStateChangeType::StateChange,
changed: VecMap::new(),
push_state,
}
}
pub fn to_json(&self) -> String {
serde_json::to_string(self).unwrap()
}
}

View file

@ -51,6 +51,7 @@ pub struct Response {
pub session_state: u32,
#[serde(rename = "createdIds")]
#[serde(skip_serializing_if = "HashMap::is_empty")]
pub created_ids: HashMap<String, Id>,
}

View file

@ -34,6 +34,8 @@ p256 = { version = "0.13", features = ["ecdh"] }
hkdf = "0.12.3"
sha2 = "0.10.1"
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"]}
tokio-tungstenite = "0.19.0"
tungstenite = "0.19.0"
[dev-dependencies]
ece = "2.2"

View file

@ -101,6 +101,9 @@ impl crate::Config {
oauth_max_auth_attempts: settings.property_or_static("oauth.max-auth-attempts", "3")?,
event_source_throttle: settings
.property_or_static("jmap.event-source.throttle", "1s")?,
web_socket_throttle: settings.property_or_static("jmap.web-socket.throttle", "1s")?,
web_socket_timeout: settings.property_or_static("jmap.web-socket.timeout", "10m")?,
web_socket_heartbeat: settings.property_or_static("jmap.web-socket.heartbeat", "1m")?,
push_max_total: settings.property_or_static("jmap.push.max-total", "100")?,
};
config.add_capabilites(settings);

View file

@ -24,7 +24,7 @@ struct Ping {
impl JMAP {
pub async fn handle_event_source(
&self,
req: &HttpRequest,
req: HttpRequest,
acl_token: Arc<AclToken>,
) -> HttpResponse {
// Parse query

View file

@ -10,6 +10,7 @@ use hyper::{
};
use jmap_proto::{
error::request::{RequestError, RequestLimitError},
request::Request,
response::Response,
types::{blob::BlobId, id::Id},
};
@ -23,39 +24,96 @@ use crate::{
auth::oauth::OAuthMetadata,
blob::{DownloadResponse, UploadResponse},
services::state,
websocket::upgrade::upgrade_websocket_connection,
JMAP,
};
use super::{session::Session, HtmlResponse, HttpResponse, JmapSessionManager, JsonResponse};
use super::{
session::Session, HtmlResponse, HttpRequest, HttpResponse, JmapSessionManager, JsonResponse,
};
impl JMAP {
pub async fn parse_request(
&self,
req: &mut hyper::Request<hyper::body::Incoming>,
remote_ip: IpAddr,
instance: &Arc<ServerInstance>,
) -> HttpResponse {
let mut path = req.uri().path().split('/');
path.next();
pub async fn parse_jmap_request(
jmap: Arc<JMAP>,
mut req: HttpRequest,
remote_ip: IpAddr,
instance: Arc<ServerInstance>,
) -> HttpResponse {
let mut path = req.uri().path().split('/');
path.next();
match path.next().unwrap_or("") {
"jmap" => {
// Authenticate request
let (_in_flight, acl_token) = match self.authenticate_headers(req, remote_ip).await
{
Ok(Some(session)) => session,
Ok(None) => return RequestError::unauthorized().into_http_response(),
Err(err) => return err.into_http_response(),
};
match path.next().unwrap_or("") {
"jmap" => {
// Authenticate request
let (_in_flight, acl_token) = match jmap.authenticate_headers(&req, remote_ip).await {
Ok(Some(session)) => session,
Ok(None) => return RequestError::unauthorized().into_http_response(),
Err(err) => return err.into_http_response(),
};
match (path.next().unwrap_or(""), req.method()) {
("", &Method::POST) => {
return match fetch_body(req, self.config.request_max_size).await {
match (path.next().unwrap_or(""), req.method()) {
("", &Method::POST) => {
return match fetch_body(&mut req, jmap.config.request_max_size)
.await
.and_then(|bytes| {
Request::parse(
&bytes,
jmap.config.request_max_calls,
jmap.config.request_max_size,
)
}) {
Ok(request) => {
//let _ = println!("<- {}", String::from_utf8_lossy(&bytes));
match jmap.handle_request(request, acl_token, &instance).await {
Ok(response) => response.into_http_response(),
Err(err) => err.into_http_response(),
}
}
Err(err) => err.into_http_response(),
};
}
("download", &Method::GET) => {
if let (Some(_), Some(blob_id), Some(name)) = (
path.next().and_then(|p| Id::from_bytes(p.as_bytes())),
path.next().and_then(BlobId::from_base32),
path.next(),
) {
return match jmap.blob_download(&blob_id, &acl_token).await {
Ok(Some(blob)) => DownloadResponse {
filename: name.to_string(),
content_type: req
.uri()
.query()
.and_then(|q| {
form_urlencoded::parse(q.as_bytes())
.find(|(k, _)| k == "accept")
.map(|(_, v)| v.into_owned())
})
.unwrap_or("application/octet-stream".to_string()),
blob,
}
.into_http_response(),
Ok(None) => RequestError::not_found().into_http_response(),
Err(_) => RequestError::internal_server_error().into_http_response(),
};
}
}
("upload", &Method::POST) => {
if let Some(account_id) = path.next().and_then(|p| Id::from_bytes(p.as_bytes()))
{
return match fetch_body(&mut req, jmap.config.upload_max_size).await {
Ok(bytes) => {
//let delete = "fd";
//println!("<- {}", String::from_utf8_lossy(&bytes));
match self.handle_request(&bytes, acl_token, instance).await {
match jmap
.blob_upload(
account_id,
req.headers()
.get(CONTENT_TYPE)
.and_then(|h| h.to_str().ok())
.unwrap_or("application/octet-stream"),
&bytes,
)
.await
{
Ok(response) => response.into_http_response(),
Err(err) => err.into_http_response(),
}
@ -63,141 +121,90 @@ impl JMAP {
Err(err) => err.into_http_response(),
};
}
("download", &Method::GET) => {
if let (Some(_), Some(blob_id), Some(name)) = (
path.next().and_then(|p| Id::from_bytes(p.as_bytes())),
path.next().and_then(BlobId::from_base32),
path.next(),
) {
return match self.blob_download(&blob_id, &acl_token).await {
Ok(Some(blob)) => DownloadResponse {
filename: name.to_string(),
content_type: req
.uri()
.query()
.and_then(|q| {
form_urlencoded::parse(q.as_bytes())
.find(|(k, _)| k == "accept")
.map(|(_, v)| v.into_owned())
})
.unwrap_or("application/octet-stream".to_string()),
blob,
}
.into_http_response(),
Ok(None) => RequestError::not_found().into_http_response(),
Err(_) => {
RequestError::internal_server_error().into_http_response()
}
};
}
}
("upload", &Method::POST) => {
if let Some(account_id) =
path.next().and_then(|p| Id::from_bytes(p.as_bytes()))
{
return match fetch_body(req, self.config.upload_max_size).await {
Ok(bytes) => {
match self
.blob_upload(
account_id,
req.headers()
.get(CONTENT_TYPE)
.and_then(|h| h.to_str().ok())
.unwrap_or("application/octet-stream"),
&bytes,
)
.await
{
Ok(response) => response.into_http_response(),
Err(err) => err.into_http_response(),
}
}
Err(err) => err.into_http_response(),
};
}
}
("eventsource", &Method::GET) => {
return self.handle_event_source(req, acl_token).await
}
("ws", &Method::GET) => {
todo!()
}
_ => (),
}
}
".well-known" => match (path.next().unwrap_or(""), req.method()) {
("jmap", &Method::GET) => {
// Authenticate request
let (_in_flight, acl_token) =
match self.authenticate_headers(req, remote_ip).await {
Ok(Some(session)) => session,
Ok(None) => return RequestError::unauthorized().into_http_response(),
Err(err) => return err.into_http_response(),
};
return match self.handle_session_resource(instance, acl_token).await {
Ok(session) => session.into_http_response(),
Err(err) => err.into_http_response(),
};
("eventsource", &Method::GET) => {
return jmap.handle_event_source(req, acl_token).await
}
("oauth-authorization-server", &Method::GET) => {
let remote_addr = self.build_remote_addr(req, remote_ip);
// Limit anonymous requests
return match self.is_anonymous_allowed(remote_addr) {
Ok(_) => JsonResponse::new(OAuthMetadata::new(&instance.data))
.into_http_response(),
Err(err) => err.into_http_response(),
};
("ws", &Method::GET) => {
return upgrade_websocket_connection(jmap, req, acl_token, instance.clone())
.await;
}
_ => (),
},
"auth" => {
let remote_addr = self.build_remote_addr(req, remote_ip);
}
}
".well-known" => match (path.next().unwrap_or(""), req.method()) {
("jmap", &Method::GET) => {
// Authenticate request
let (_in_flight, acl_token) = match jmap.authenticate_headers(&req, remote_ip).await
{
Ok(Some(session)) => session,
Ok(None) => return RequestError::unauthorized().into_http_response(),
Err(err) => return err.into_http_response(),
};
match (path.next().unwrap_or(""), req.method()) {
("", &Method::GET) => {
return match self.is_anonymous_allowed(remote_addr) {
Ok(_) => self.handle_user_device_auth(req).await,
Err(err) => err.into_http_response(),
}
return match jmap.handle_session_resource(instance, acl_token).await {
Ok(session) => session.into_http_response(),
Err(err) => err.into_http_response(),
};
}
("oauth-authorization-server", &Method::GET) => {
let remote_addr = jmap.build_remote_addr(&req, remote_ip);
// Limit anonymous requests
return match jmap.is_anonymous_allowed(remote_addr) {
Ok(_) => {
JsonResponse::new(OAuthMetadata::new(&instance.data)).into_http_response()
}
("", &Method::POST) => {
return match self.is_auth_allowed(remote_addr) {
Ok(_) => self.handle_user_device_auth_post(req).await,
Err(err) => err.into_http_response(),
}
}
("code", &Method::GET) => {
return match self.is_anonymous_allowed(remote_addr) {
Ok(_) => self.handle_user_code_auth(req).await,
Err(err) => err.into_http_response(),
}
}
("code", &Method::POST) => {
return match self.is_auth_allowed(remote_addr) {
Ok(_) => self.handle_user_code_auth_post(req).await,
Err(err) => err.into_http_response(),
}
}
("device", &Method::POST) => {
return match self.is_anonymous_allowed(remote_addr) {
Ok(_) => self.handle_device_auth(req, instance).await,
Err(err) => err.into_http_response(),
}
}
("token", &Method::POST) => {
return match self.is_anonymous_allowed(remote_addr) {
Ok(_) => self.handle_token_request(req).await,
Err(err) => err.into_http_response(),
}
}
_ => (),
}
Err(err) => err.into_http_response(),
};
}
_ => (),
},
"auth" => {
let remote_addr = jmap.build_remote_addr(&req, remote_ip);
match (path.next().unwrap_or(""), req.method()) {
("", &Method::GET) => {
return match jmap.is_anonymous_allowed(remote_addr) {
Ok(_) => jmap.handle_user_device_auth(&mut req).await,
Err(err) => err.into_http_response(),
}
}
("", &Method::POST) => {
return match jmap.is_auth_allowed(remote_addr) {
Ok(_) => jmap.handle_user_device_auth_post(&mut req).await,
Err(err) => err.into_http_response(),
}
}
("code", &Method::GET) => {
return match jmap.is_anonymous_allowed(remote_addr) {
Ok(_) => jmap.handle_user_code_auth(&mut req).await,
Err(err) => err.into_http_response(),
}
}
("code", &Method::POST) => {
return match jmap.is_auth_allowed(remote_addr) {
Ok(_) => jmap.handle_user_code_auth_post(&mut req).await,
Err(err) => err.into_http_response(),
}
}
("device", &Method::POST) => {
return match jmap.is_anonymous_allowed(remote_addr) {
Ok(_) => jmap.handle_device_auth(&mut req, instance).await,
Err(err) => err.into_http_response(),
}
}
("token", &Method::POST) => {
return match jmap.is_anonymous_allowed(remote_addr) {
Ok(_) => jmap.handle_token_request(&mut req).await,
Err(err) => err.into_http_response(),
}
}
_ => (),
}
}
RequestError::not_found().into_http_response()
_ => (),
}
RequestError::not_found().into_http_response()
}
impl SessionManager for JmapSessionManager {
@ -246,7 +253,7 @@ impl SessionManager for JmapSessionManager {
}
}
async fn handle_request<T: AsyncRead + AsyncWrite + Unpin + 'static>(
async fn handle_request<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
jmap: Arc<JMAP>,
session: SessionData<T>,
) {
@ -257,27 +264,25 @@ async fn handle_request<T: AsyncRead + AsyncWrite + Unpin + 'static>(
.keep_alive(true)
.serve_connection(
session.stream,
service_fn(|mut req: hyper::Request<body::Incoming>| {
service_fn(|req: hyper::Request<body::Incoming>| {
let jmap = jmap.clone();
let span = span.clone();
let instance = session.instance.clone();
async move {
let response = jmap
.parse_request(&mut req, session.remote_ip, &instance)
.await;
tracing::debug!(
parent: &span,
event = "request",
uri = req.uri().to_string(),
status = response.status().to_string(),
);
let response = parse_jmap_request(jmap, req, session.remote_ip, instance).await;
Ok::<_, hyper::Error>(response)
}
}),
)
.with_upgrades()
.await
{
tracing::debug!(
@ -289,10 +294,7 @@ async fn handle_request<T: AsyncRead + AsyncWrite + Unpin + 'static>(
}
}
pub async fn fetch_body(
req: &mut hyper::Request<hyper::body::Incoming>,
max_size: usize,
) -> Result<Vec<u8>, RequestError> {
pub async fn fetch_body(req: &mut HttpRequest, max_size: usize) -> Result<Vec<u8>, RequestError> {
let mut bytes = Vec::with_capacity(1024);
while let Some(Ok(frame)) = req.frame().await {
if let Some(data) = frame.data_ref() {

View file

@ -17,15 +17,10 @@ use crate::{auth::AclToken, JMAP};
impl JMAP {
pub async fn handle_request(
&self,
bytes: &[u8],
request: Request,
acl_token: Arc<AclToken>,
instance: &Arc<ServerInstance>,
) -> Result<Response, RequestError> {
let request = Request::parse(
bytes,
self.config.request_max_calls,
self.config.request_max_size,
)?;
let mut response = Response::new(
acl_token.state(),
request.created_ids.unwrap_or_default(),

View file

@ -143,7 +143,7 @@ pub struct BaseCapabilities {
impl JMAP {
pub async fn handle_session_resource(
&self,
instance: &ServerInstance,
instance: Arc<ServerInstance>,
acl_token: Arc<AclToken>,
) -> Result<Session, RequestError> {
let mut session = Session::new(&instance.data, &self.config.capabilities);

View file

@ -31,7 +31,7 @@ impl JMAP {
pub async fn handle_device_auth(
&self,
req: &mut HttpRequest,
instance: &ServerInstance,
instance: Arc<ServerInstance>,
) -> HttpResponse {
// Parse form
let client_id = match parse_form_data(req)

View file

@ -46,6 +46,7 @@ pub mod sieve;
pub mod submission;
pub mod thread;
pub mod vacation;
pub mod websocket;
pub const SUPERUSER_ID: u32 = 0;
pub const LONG_SLUMBER: Duration = Duration::from_secs(60 * 60 * 24);
@ -103,6 +104,10 @@ pub struct Config {
pub event_source_throttle: Duration,
pub push_max_total: usize,
pub web_socket_throttle: Duration,
pub web_socket_timeout: Duration,
pub web_socket_heartbeat: Duration,
pub oauth_key: String,
pub oauth_expiry_user_code: u64,
pub oauth_expiry_auth_code: u64,

View file

@ -0,0 +1,2 @@
pub mod stream;
pub mod upgrade;

View file

@ -0,0 +1,192 @@
use std::{sync::Arc, time::Instant};
use futures_util::{SinkExt, StreamExt};
use hyper::upgrade::Upgraded;
use jmap_proto::{
error::request::RequestError,
request::websocket::{
WebSocketMessage, WebSocketRequestError, WebSocketResponse, WebSocketStateChange,
},
types::type_state::TypeState,
};
use tokio_tungstenite::WebSocketStream;
use tungstenite::Message;
use utils::{listener::ServerInstance, map::bitmap::Bitmap};
use crate::{auth::AclToken, JMAP};
impl JMAP {
pub async fn handle_websocket_stream(
&self,
mut stream: WebSocketStream<Upgraded>,
acl_token: Arc<AclToken>,
instance: Arc<ServerInstance>,
) {
let span = tracing::info_span!(
"WebSocket connection established",
"account_id" = acl_token.primary_id(),
"url" = instance.data,
);
// Set timeouts
let throttle = self.config.web_socket_throttle;
let timeout = self.config.web_socket_timeout;
let heartbeat = self.config.web_socket_heartbeat;
let mut last_request = Instant::now();
let mut last_changes_sent = Instant::now() - throttle;
let mut last_heartbeat = Instant::now() - heartbeat;
let mut next_event = heartbeat;
// Register with state manager
let mut change_rx = if let Some(change_rx) = self
.subscribe_state_manager(
acl_token.primary_id(),
acl_token.primary_id(),
Bitmap::all(),
)
.await
{
change_rx
} else {
let _ = stream
.send(Message::Text(
WebSocketRequestError::from(RequestError::internal_server_error()).to_json(),
))
.await;
return;
};
let mut changes = WebSocketStateChange::new(None);
let mut change_types: Bitmap<TypeState> = Bitmap::new();
loop {
tokio::select! {
event = tokio::time::timeout(next_event, stream.next()) => {
match event {
Ok(Some(Ok(event))) => {
match event {
Message::Text(text) => {
let response = match WebSocketMessage::parse(
text.as_bytes(),
self.config.request_max_calls,
self.config.request_max_size,
) {
Ok(WebSocketMessage::Request(request)) => {
match self
.handle_request(
request.request,
acl_token.clone(),
&instance,
)
.await
{
Ok(response) => {
WebSocketResponse::from_response(response, request.id)
.to_json()
}
Err(err) => {
WebSocketRequestError::from_error(err, request.id)
.to_json()
}
}
}
Ok(WebSocketMessage::PushEnable(push_enable)) => {
change_types = if !push_enable.data_types.is_empty() {
push_enable.data_types.into()
} else {
Bitmap::all()
};
continue;
}
Ok(WebSocketMessage::PushDisable) => {
change_types = Bitmap::new();
continue;
}
Err(err) => err.to_json(),
};
if let Err(err) = stream.send(Message::Text(response)).await {
tracing::debug!(parent: &span, error = ?err, "Failed to send text message");
}
}
Message::Ping(bytes) => {
if let Err(err) = stream.send(Message::Pong(bytes)).await {
tracing::debug!(parent: &span, error = ?err, "Failed to send pong message");
}
}
Message::Close(frame) => {
let _ = stream.close(frame).await;
break;
}
_ => (),
}
last_request = Instant::now();
last_heartbeat = Instant::now();
}
Ok(Some(Err(err))) => {
tracing::debug!(parent: &span, error = ?err, "Websocket error");
break;
}
Ok(None) => break,
Err(_) => {
// Verify timeout
if last_request.elapsed() > timeout {
tracing::debug!(
parent: &span,
event = "disconnect",
"Disconnecting idle client"
);
break;
}
}
}
}
state_change = change_rx.recv() => {
if let Some(state_change) = state_change {
if !change_types.is_empty() && state_change
.types
.iter()
.any(|(t, _)| change_types.contains(*t))
{
for (type_state, change_id) in state_change.types {
changes
.changed
.get_mut_or_insert(state_change.account_id.into())
.set(type_state, change_id.into());
}
}
} else {
tracing::debug!(
parent: &span,
event = "channel-closed",
"Disconnecting client, channel closed"
);
break;
}
}
}
if !changes.changed.is_empty() {
// Send any queued changes
let elapsed = last_changes_sent.elapsed();
if elapsed >= throttle {
if let Err(err) = stream.send(Message::Text(changes.to_json())).await {
tracing::debug!(parent: &span, error = ?err, "Failed to send state change message");
}
changes.changed.clear();
last_changes_sent = Instant::now();
last_heartbeat = Instant::now();
next_event = heartbeat;
} else {
next_event = throttle - elapsed;
}
} else if last_heartbeat.elapsed() > heartbeat {
if let Err(err) = stream.send(Message::Ping(vec![])).await {
tracing::debug!(parent: &span, error = ?err, "Failed to send ping message");
break;
}
last_heartbeat = Instant::now();
next_event = heartbeat;
}
}
}
}

View file

@ -0,0 +1,88 @@
use std::sync::Arc;
use http_body_util::{BodyExt, Full};
use hyper::{body::Bytes, Response, StatusCode};
use jmap_proto::error::request::RequestError;
use tokio_tungstenite::WebSocketStream;
use tungstenite::{handshake::derive_accept_key, protocol::Role};
use utils::listener::ServerInstance;
use crate::{
api::{http::ToHttpResponse, HttpRequest, HttpResponse},
auth::AclToken,
JMAP,
};
pub async fn upgrade_websocket_connection(
jmap: Arc<JMAP>,
req: HttpRequest,
acl_token: Arc<AclToken>,
instance: Arc<ServerInstance>,
) -> HttpResponse {
let headers = req.headers();
if headers
.get(hyper::header::CONNECTION)
.and_then(|h| h.to_str().ok())
!= Some("Upgrade")
|| headers
.get(hyper::header::UPGRADE)
.and_then(|h| h.to_str().ok())
!= Some("websocket")
{
return RequestError::blank(
StatusCode::BAD_REQUEST.as_u16(),
"WebSocket upgrade failed",
"Missing or Invalid Connection or Upgrade headers.",
)
.into_http_response();
}
let derived_key = match (
headers
.get("Sec-WebSocket-Key")
.and_then(|h| h.to_str().ok()),
headers
.get("Sec-WebSocket-Version")
.and_then(|h| h.to_str().ok()),
) {
(Some(key), Some(version)) if version == "13" => derive_accept_key(key.as_bytes()),
_ => {
return RequestError::blank(
StatusCode::BAD_REQUEST.as_u16(),
"WebSocket upgrade failed",
"Missing or Invalid Sec-WebSocket-Key headers.",
)
.into_http_response();
}
};
// Spawn WebSocket connection
tokio::spawn(async move {
// Upgrade connection
match hyper::upgrade::on(req).await {
Ok(upgraded) => {
jmap.handle_websocket_stream(
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
acl_token,
instance,
)
.await;
}
Err(e) => {
tracing::debug!("WebSocket upgrade failed: {}", e);
}
}
});
Response::builder()
.status(hyper::StatusCode::SWITCHING_PROTOCOLS)
.header(hyper::header::CONNECTION, "upgrade")
.header(hyper::header::UPGRADE, "websocket")
.header("Sec-WebSocket-Accept", &derived_key)
.header("Sec-WebSocket-Protocol", "jmap")
.body(
Full::new(Bytes::from("Switching to WebSocket protocol"))
.map_err(|never| match never {})
.boxed(),
)
.unwrap()
}

View file

@ -29,6 +29,7 @@ pub mod sieve_script;
pub mod thread_get;
pub mod thread_merge;
pub mod vacation_response;
pub mod websocket;
const SERVER: &str = r#"
[server]
@ -193,9 +194,8 @@ pub async fn jmap_tests() {
//push_subscription::test(params.server.clone(), &mut params.client).await;
//sieve_script::test(params.server.clone(), &mut params.client).await;
//vacation_response::test(params.server.clone(), &mut params.client).await;
email_submission::test(params.server.clone(), &mut params.client).await;
let websockets = "todo";
//email_submission::test(params.server.clone(), &mut params.client).await;
websocket::test(params.server.clone(), &mut params.client).await;
if delete {
params.temp_dir.delete();

148
tests/src/jmap/websocket.rs Normal file
View file

@ -0,0 +1,148 @@
use std::{sync::Arc, time::Duration};
use ahash::AHashSet;
use futures::StreamExt;
use jmap::JMAP;
use jmap_client::{
client::Client,
client_ws::WebSocketMessage,
core::{
response::{Response, TaggedMethodResponse},
set::SetObject,
},
TypeState,
};
use jmap_proto::types::id::Id;
use tokio::sync::mpsc;
use crate::jmap::{mailbox::destroy_all_mailboxes, test_account_create, test_account_login};
pub async fn test(server: Arc<JMAP>, admin_client: &mut Client) {
println!("Running WebSockets tests...");
// Authenticate all accounts
let account_id = test_account_create(&server, "jdoe@example.com", "12345", "John Doe")
.await
.to_string();
let client = test_account_login("jdoe@example.com", "12345").await;
let mut ws_stream = client.connect_ws().await.unwrap();
let (stream_tx, mut stream_rx) = mpsc::channel::<WebSocketMessage>(100);
tokio::spawn(async move {
while let Some(change) = ws_stream.next().await {
stream_tx.send(change.unwrap()).await.unwrap();
}
});
// Create mailbox
let mut request = client.build();
let create_id = request
.set_mailbox()
.create()
.name("WebSocket Test")
.create_id()
.unwrap();
let request_id = request.send_ws().await.unwrap();
let mut response = expect_response(&mut stream_rx).await;
assert_eq!(request_id, response.request_id().unwrap());
let mailbox_id = response
.pop_method_response()
.unwrap()
.unwrap_set_mailbox()
.unwrap()
.created(&create_id)
.unwrap()
.take_id();
// Enable push notifications
client
.enable_push_ws(None::<Vec<_>>, None::<&str>)
.await
.unwrap();
// Make changes over standard HTTP and expect a push notification via WebSockets
client
.mailbox_update_sort_order(&mailbox_id, 1)
.await
.unwrap();
assert_state(&mut stream_rx, &[TypeState::Mailbox]).await;
// Multiple changes should be grouped and delivered in intervals
for num in 0..5 {
client
.mailbox_update_sort_order(&mailbox_id, num)
.await
.unwrap();
}
tokio::time::sleep(Duration::from_millis(500)).await;
assert_state(&mut stream_rx, &[TypeState::Mailbox]).await;
expect_nothing(&mut stream_rx).await;
// Disable push notifications
client.disable_push_ws().await.unwrap();
// No more changes should be received
let mut request = client.build();
request.set_mailbox().destroy([&mailbox_id]);
request.send_ws().await.unwrap();
expect_response(&mut stream_rx)
.await
.pop_method_response()
.unwrap()
.unwrap_set_mailbox()
.unwrap()
.destroyed(&mailbox_id)
.unwrap();
expect_nothing(&mut stream_rx).await;
admin_client.set_default_account_id(account_id);
destroy_all_mailboxes(admin_client).await;
server.store.assert_is_empty().await;
}
async fn expect_response(
stream_rx: &mut mpsc::Receiver<WebSocketMessage>,
) -> Response<TaggedMethodResponse> {
match tokio::time::timeout(Duration::from_millis(100), stream_rx.recv()).await {
Ok(Some(message)) => match message {
WebSocketMessage::Response(response) => response,
_ => panic!("Expected response, got: {:?}", message),
},
result => {
panic!("Timeout waiting for websocket: {:?}", result);
}
}
}
async fn assert_state(stream_rx: &mut mpsc::Receiver<WebSocketMessage>, state: &[TypeState]) {
match tokio::time::timeout(Duration::from_millis(700), stream_rx.recv()).await {
Ok(Some(message)) => match message {
WebSocketMessage::StateChange(changes) => {
assert_eq!(
changes
.changes(&Id::new(1).to_string())
.unwrap()
.map(|x| x.0)
.collect::<AHashSet<&TypeState>>(),
state.iter().collect::<AHashSet<&TypeState>>()
);
}
_ => panic!("Expected state change, got: {:?}", message),
},
result => {
panic!("Timeout waiting for websocket: {:?}", result);
}
}
}
async fn expect_nothing(stream_rx: &mut mpsc::Receiver<WebSocketMessage>) {
match tokio::time::timeout(Duration::from_millis(1000), stream_rx.recv()).await {
Err(_) => {}
message => {
panic!("Received a message when expecting nothing: {:?}", message);
}
}
}