First API tests.

This commit is contained in:
Mauro D 2023-04-19 16:55:37 +00:00
parent 0264890595
commit f8acc4fe5a
57 changed files with 2243 additions and 664 deletions

View file

@ -1,18 +1,26 @@
#[package]
#name = "stalwart-jmap"
#description = "Stalwart JMAP Server"
#authors = [ "Stalwart Labs Ltd. <hello@stalw.art>"]
#repository = "https://github.com/stalwartlabs/jmap-server"
#homepage = "https://stalw.art/jmap"
#keywords = ["jmap", "email", "mail", "server"]
#categories = ["email"]
#license = "AGPL-3.0-only"
#version = "0.3.0"
#edition = "2021"
#resolver = "2"
[package]
name = "mail-server"
description = "Stalwart Mail Server"
authors = [ "Stalwart Labs Ltd. <hello@stalw.art>"]
repository = "https://github.com/stalwartlabs/jmap-server"
homepage = "https://stalw.art"
keywords = ["imap", "jmap", "smtp", "email", "mail", "server"]
categories = ["email"]
license = "AGPL-3.0-only"
version = "0.3.0"
edition = "2021"
resolver = "2"
#[lib]
#path = "crates/core/src/lib.rs"
[[bin]]
name = "stalwart-mail"
path = "crates/main/src/main.rs"
[dependencies]
store = { path = "crates/store" }
jmap = { path = "crates/jmap" }
jmap_proto = { path = "crates/jmap-proto" }
utils = { path = "crates/utils" }
tests = { path = "tests" }
[workspace]
members = [
@ -20,5 +28,6 @@ members = [
"crates/jmap-proto",
"crates/store",
"crates/utils",
"crates/maybe-async",
"tests",
]

View file

@ -2,6 +2,7 @@
name = "jmap_proto"
version = "0.1.0"
edition = "2021"
resolver = "2"
[dependencies]
store = { path = "../store" }
@ -11,3 +12,4 @@ fast-float = "0.2.0"
serde = { version = "1.0", features = ["derive"]}
ahash = { version = "0.8.0", features = ["serde"] }
serde_json = { version = "1.0", features = ["raw_value"] }
tracing = "0.1"

View file

@ -133,7 +133,7 @@ impl Serialize for MethodError {
"serverPartialFail",
concat!(
"Some, but not all, expected changes described by the method ",
"occurred. Please resynchronize to determine server state."
"occurred. Please resynchronize to determine server state."
),
),
MethodError::InvalidResultReference(description) => {
@ -164,10 +164,3 @@ impl Serialize for MethodError {
map.end()
}
}
/// Converts a storage-layer error into a generic JMAP method error.
///
/// All `store::Error` detail is intentionally discarded and mapped to
/// `serverPartialFail`; internal storage failures are not surfaced to
/// the client. (Removed the leftover debug local `let log = "true";`,
/// which was unused dead code.)
impl From<store::Error> for MethodError {
    fn from(_value: store::Error) -> Self {
        MethodError::ServerPartialFail
    }
}

View file

@ -6,7 +6,7 @@ pub mod sieve;
use std::slice::Iter;
use store::{
write::{IntoBitmap, Operation, ToBitmaps},
write::{DeserializeFrom, SerializeInto, ToBitmaps},
Deserialize, Serialize,
};
use utils::{
@ -20,6 +20,7 @@ use crate::types::{
};
#[derive(Debug, Clone, Default, serde::Serialize, PartialEq, Eq)]
#[serde(transparent)]
pub struct Object<T> {
pub properties: VecMap<Property, T>,
}
@ -57,46 +58,14 @@ impl ToBitmaps for Value {
fn to_bitmaps(&self, ops: &mut Vec<store::write::Operation>, field: u8, set: bool) {
match self {
Value::Text(text) => text.as_str().to_bitmaps(ops, field, set),
Value::Keyword(keyword) => {
let (key, family) = keyword.into_bitmap();
ops.push(Operation::Bitmap {
family,
field,
key,
set,
});
}
Value::UnsignedInt(int) => {
let (key, family) = (*int as u32).into_bitmap();
ops.push(Operation::Bitmap {
family,
field,
key,
set,
});
}
Value::Keyword(keyword) => keyword.to_bitmaps(ops, field, set),
Value::UnsignedInt(int) => int.to_bitmaps(ops, field, set),
Value::List(items) => {
for item in items {
match item {
Value::Text(text) => text.as_str().to_bitmaps(ops, field, set),
Value::UnsignedInt(int) => {
let (key, family) = (*int as u32).into_bitmap();
ops.push(Operation::Bitmap {
family,
field,
key,
set,
});
}
Value::Keyword(keyword) => {
let (key, family) = keyword.into_bitmap();
ops.push(Operation::Bitmap {
family,
field,
key,
set,
})
}
Value::UnsignedInt(int) => int.to_bitmaps(ops, field, set),
Value::Keyword(keyword) => keyword.to_bitmaps(ops, field, set),
_ => (),
}
}
@ -106,6 +75,12 @@ impl ToBitmaps for Value {
}
}
// Objects are never indexed as bitmap tags; reaching this impl would be
// a caller bug, hence the unconditional `unreachable!()`.
impl ToBitmaps for Object<Value> {
fn to_bitmaps(&self, _ops: &mut Vec<store::write::Operation>, _field: u8, _set: bool) {
unreachable!()
}
}
const TEXT: u8 = 0;
const UNSIGNED_INT: u8 = 1;
const BOOL_TRUE: u8 = 2;
@ -123,14 +98,14 @@ const NULL: u8 = 12;
impl Serialize for Value {
fn serialize(self) -> Vec<u8> {
let mut buf = Vec::with_capacity(1024);
self.serialize_value(&mut buf);
self.serialize_into(&mut buf);
buf
}
}
impl Deserialize for Value {
fn deserialize(bytes: &[u8]) -> store::Result<Self> {
Self::deserialize_value(&mut bytes.iter())
Self::deserialize_from(&mut bytes.iter())
.ok_or_else(|| store::Error::InternalError("Failed to deserialize value.".to_string()))
}
}
@ -150,78 +125,71 @@ impl Deserialize for Object<Value> {
}
}
impl Object<Value> {
fn serialize_into(self, buf: &mut Vec<u8>) {
impl SerializeInto for Object<Value> {
fn serialize_into(&self, buf: &mut Vec<u8>) {
buf.push_leb128(self.properties.len());
for (k, v) in self.properties {
k.serialize_value(buf);
v.serialize_value(buf);
for (k, v) in &self.properties {
k.serialize_into(buf);
v.serialize_into(buf);
}
}
}
impl DeserializeFrom for Object<Value> {
fn deserialize_from(bytes: &mut Iter<'_, u8>) -> Option<Object<Value>> {
let len = bytes.next_leb128()?;
let mut properties = VecMap::with_capacity(len);
for _ in 0..len {
let key = Property::deserialize_value(bytes)?;
let value = Value::deserialize_value(bytes)?;
let key = Property::deserialize_from(bytes)?;
let value = Value::deserialize_from(bytes)?;
properties.append(key, value);
}
Some(Object { properties })
}
}
pub trait SerializeValue {
fn serialize_value(self, buf: &mut Vec<u8>);
}
pub trait DeserializeValue: Sized {
fn deserialize_value(bytes: &mut Iter<'_, u8>) -> Option<Self>;
}
impl SerializeValue for Value {
fn serialize_value(self, buf: &mut Vec<u8>) {
impl SerializeInto for Value {
fn serialize_into(&self, buf: &mut Vec<u8>) {
match self {
Value::Text(v) => {
buf.push(TEXT);
v.serialize_value(buf);
v.serialize_into(buf);
}
Value::UnsignedInt(v) => {
buf.push(UNSIGNED_INT);
v.serialize_value(buf);
v.serialize_into(buf);
}
Value::Bool(v) => {
buf.push(if v { BOOL_TRUE } else { BOOL_FALSE });
buf.push(if *v { BOOL_TRUE } else { BOOL_FALSE });
}
Value::Id(v) => {
buf.push(ID);
v.id().serialize_value(buf);
v.id().serialize_into(buf);
}
Value::Date(v) => {
buf.push(DATE);
(v.timestamp() as u64).serialize_value(buf);
(v.timestamp() as u64).serialize_into(buf);
}
Value::BlobId(v) => {
buf.push(BLOB_ID);
v.serialize_value(buf);
v.serialize_into(buf);
}
Value::Keyword(v) => {
buf.push(KEYWORD);
v.serialize_value(buf);
v.serialize_into(buf);
}
Value::TypeState(v) => {
buf.push(TYPE_STATE);
v.serialize_value(buf);
v.serialize_into(buf);
}
Value::Acl(v) => {
buf.push(ACL);
v.serialize_value(buf);
v.serialize_into(buf);
}
Value::List(v) => {
buf.push(LIST);
buf.push_leb128(v.len());
for i in v {
i.serialize_value(buf);
i.serialize_into(buf);
}
}
Value::Object(v) => {
@ -235,10 +203,10 @@ impl SerializeValue for Value {
}
}
impl DeserializeValue for Value {
fn deserialize_value(bytes: &mut Iter<'_, u8>) -> Option<Self> {
impl DeserializeFrom for Value {
fn deserialize_from(bytes: &mut Iter<'_, u8>) -> Option<Self> {
match *bytes.next()? {
TEXT => Some(Value::Text(String::deserialize_value(bytes)?)),
TEXT => Some(Value::Text(String::deserialize_from(bytes)?)),
UNSIGNED_INT => Some(Value::UnsignedInt(bytes.next_leb128()?)),
BOOL_TRUE => Some(Value::Bool(true)),
BOOL_FALSE => Some(Value::Bool(false)),
@ -246,15 +214,15 @@ impl DeserializeValue for Value {
DATE => Some(Value::Date(UTCDate::from_timestamp(
bytes.next_leb128::<u64>()? as i64,
))),
BLOB_ID => Some(Value::BlobId(BlobId::deserialize_value(bytes)?)),
KEYWORD => Some(Value::Keyword(Keyword::deserialize_value(bytes)?)),
TYPE_STATE => Some(Value::TypeState(TypeState::deserialize_value(bytes)?)),
ACL => Some(Value::Acl(Acl::deserialize_value(bytes)?)),
BLOB_ID => Some(Value::BlobId(BlobId::deserialize_from(bytes)?)),
KEYWORD => Some(Value::Keyword(Keyword::deserialize_from(bytes)?)),
TYPE_STATE => Some(Value::TypeState(TypeState::deserialize_from(bytes)?)),
ACL => Some(Value::Acl(Acl::deserialize_from(bytes)?)),
LIST => {
let len = bytes.next_leb128()?;
let mut items = Vec::with_capacity(len);
for _ in 0..len {
items.push(Value::deserialize_value(bytes)?);
items.push(Value::deserialize_from(bytes)?);
}
Some(Value::List(items))
}
@ -264,35 +232,3 @@ impl DeserializeValue for Value {
}
}
}
// Writes the string as a LEB128 length prefix followed by its raw UTF-8
// bytes (only the prefix is written for an empty string).
impl SerializeValue for String {
fn serialize_value(self, buf: &mut Vec<u8>) {
buf.push_leb128(self.len());
if !self.is_empty() {
buf.extend_from_slice(self.as_bytes());
}
}
}
// Reads a LEB128 length prefix, then exactly that many bytes, and
// validates them as UTF-8. Returns `None` on truncated input or on
// invalid UTF-8 (`from_utf8(...).ok()`).
impl DeserializeValue for String {
fn deserialize_value(bytes: &mut Iter<'_, u8>) -> Option<Self> {
let len: usize = bytes.next_leb128()?;
let mut s = Vec::with_capacity(len);
for _ in 0..len {
s.push(*bytes.next()?);
}
String::from_utf8(s).ok()
}
}
// Encodes the integer as a LEB128 varint.
impl SerializeValue for u64 {
fn serialize_value(self, buf: &mut Vec<u8>) {
buf.push_leb128(self);
}
}
// Decodes a LEB128 varint; `None` on truncated or malformed input.
impl DeserializeValue for u64 {
fn deserialize_value(bytes: &mut Iter<'_, u8>) -> Option<Self> {
bytes.next_leb128()
}
}

View file

@ -35,7 +35,6 @@ impl<'x> Parser<'x> {
}
/// Builds a parse `Error` describing `message` at the parser's current
/// byte offset (`self.pos`).
///
/// Removed a leftover debug `println!` that dumped the entire remaining
/// unparsed input to stdout on every error — noisy and a potential leak
/// of request contents into logs.
pub fn error(&self, message: &str) -> Error {
    format!("{message} at position {}.", self.pos).into()
}

View file

@ -113,7 +113,7 @@ impl Display for MethodName {
}
impl MethodName {
pub fn unknown_method() -> Self {
pub fn error() -> Self {
Self {
obj: MethodObject::Thread,
fnc: MethodFunction::Echo,

View file

@ -27,7 +27,7 @@ use crate::{
types::id::Id,
};
use self::echo::Echo;
use self::{echo::Echo, method::MethodName};
#[derive(Debug)]
pub struct Request {
@ -36,9 +36,10 @@ pub struct Request {
pub created_ids: Option<HashMap<String, Id>>,
}
#[derive(Debug, serde::Serialize)]
#[derive(Debug)]
pub struct Call<T> {
pub id: String,
pub name: MethodName,
pub method: T,
}

View file

@ -60,7 +60,7 @@ impl Request {
parser
.next_token::<Ignore>()?
.assert_jmap(Token::ArrayStart)?;
let method = match parser.next_token::<MethodName>() {
let method_name = match parser.next_token::<MethodName>() {
Ok(Token::String(method)) => method,
Ok(_) => {
return Err(RequestError::not_request(
@ -68,18 +68,18 @@ impl Request {
));
}
Err(Error::Method(MethodError::InvalidArguments(_))) => {
MethodName::unknown_method()
MethodName::error()
}
Err(err) => {
return Err(err.into());
}
};
parser.next_token::<Ignore>()?.assert_jmap(Token::Comma)?;
parser.ctx = method.obj;
parser.ctx = method_name.obj;
let start_depth_array = parser.depth_array;
let start_depth_dict = parser.depth_dict;
let method = match (&method.fnc, &method.obj) {
let method = match (&method_name.fnc, &method_name.obj) {
(MethodFunction::Get, _) => {
GetRequest::parse(&mut parser).map(RequestMethod::Get)
}
@ -120,7 +120,7 @@ impl Request {
Echo::parse(&mut parser).map(RequestMethod::Echo)
}
_ => Err(Error::Method(MethodError::UnknownMethod(
method.to_string(),
method_name.to_string(),
))),
};
@ -140,7 +140,11 @@ impl Request {
parser
.next_token::<Ignore>()?
.assert_jmap(Token::ArrayEnd)?;
request.method_calls.push(Call { id, method });
request.method_calls.push(Call {
id,
method,
name: method_name,
});
} else {
return Err(RequestError::limit(RequestLimitError::CallsIn));
}

View file

@ -1,9 +1,8 @@
pub mod references;
pub mod serialize;
use std::collections::HashMap;
use serde::Serialize;
use crate::{
error::method::MethodError,
method::{
@ -18,11 +17,14 @@ use crate::{
set::SetResponse,
validate::ValidateSieveScriptResponse,
},
request::{echo::Echo, Call},
request::{echo::Echo, method::MethodName, Call},
types::id::Id,
};
use self::serialize::serialize_hex;
#[derive(Debug, serde::Serialize)]
#[serde(untagged)]
pub enum ResponseMethod {
Get(GetResponse),
Set(SetResponse),
@ -61,10 +63,16 @@ impl Response {
}
}
pub fn push_response(&mut self, id: String, method: impl Into<ResponseMethod>) {
pub fn push_response(
&mut self,
id: String,
name: MethodName,
method: impl Into<ResponseMethod>,
) {
self.method_responses.push(Call {
id,
method: method.into(),
name,
});
}
@ -73,13 +81,6 @@ impl Response {
}
}
pub fn serialize_hex<S>(value: &u32, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
format!("{:x}", value).serialize(serializer)
}
impl From<MethodError> for ResponseMethod {
fn from(error: MethodError) -> Self {
ResponseMethod::Error(error)

View file

@ -6,7 +6,6 @@ use crate::{
error::method::MethodError,
object::Object,
request::{
method::MethodFunction,
reference::{MaybeReference, ResultReference},
RequestMethod,
},
@ -166,9 +165,9 @@ impl Response {
fn eval_result_references(&self, rr: &ResultReference) -> EvalResult {
for response in &self.method_responses {
if response.id == rr.result_of {
match (&rr.name.fnc, &response.method) {
(MethodFunction::Get, ResponseMethod::Get(response)) => {
if response.id == rr.result_of && response.name == rr.name {
match &response.method {
ResponseMethod::Get(response) => {
return match rr.path.item_subquery() {
Some((root, property)) if root == "list" => {
let property = Property::parse(property);
@ -184,7 +183,7 @@ impl Response {
_ => EvalResult::Failed,
};
}
(MethodFunction::Changes, ResponseMethod::Changes(response)) => {
ResponseMethod::Changes(response) => {
return match rr.path.item_query() {
Some("created") => EvalResult::Values(
response
@ -208,7 +207,7 @@ impl Response {
_ => EvalResult::Failed,
};
}
(MethodFunction::Query, ResponseMethod::Query(response)) => {
ResponseMethod::Query(response) => {
return if rr.path.item_query() == Some("ids") {
EvalResult::Values(
response.ids.iter().copied().map(Into::into).collect(),
@ -217,7 +216,7 @@ impl Response {
EvalResult::Failed
};
}
(MethodFunction::QueryChanges, ResponseMethod::QueryChanges(response)) => {
ResponseMethod::QueryChanges(response) => {
return if rr.path.item_subquery() == Some(("added", "id")) {
EvalResult::Values(
response.added.iter().map(|item| item.id.into()).collect(),

View file

@ -0,0 +1,25 @@
use serde::{ser::SerializeSeq, Serialize};
use crate::request::Call;
use super::ResponseMethod;
// Serializes a method response as the JMAP wire-format triple: a
// three-element JSON array `[ methodName, arguments, callId ]`.
impl Serialize for Call<ResponseMethod> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
// `3.into()` == `Some(3)`: sequence length is known up front.
let mut seq = serializer.serialize_seq(3.into())?;
seq.serialize_element(&self.name.to_string())?;
seq.serialize_element(&self.method)?;
seq.serialize_element(&self.id)?;
seq.end()
}
}
// Serializes a `u32` as its lowercase hexadecimal string (e.g. 255 -> "ff").
// Presumably referenced via `#[serde(serialize_with = "...")]` on state
// fields — confirm at the attribute sites.
pub fn serialize_hex<S>(value: &u32, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
format!("{:x}", value).serialize(serializer)
}

View file

@ -1,9 +1,8 @@
use std::fmt::{self, Display};
use crate::{
object::{DeserializeValue, SerializeValue},
parser::{json::Parser, JsonObjectParser},
};
use store::write::{DeserializeFrom, SerializeInto};
use crate::parser::{json::Parser, JsonObjectParser};
#[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Clone, Copy)]
#[repr(u8)]
@ -85,14 +84,14 @@ impl serde::Serialize for Acl {
}
}
impl SerializeValue for Acl {
fn serialize_value(self, buf: &mut Vec<u8>) {
buf.push(self as u8);
impl SerializeInto for Acl {
fn serialize_into(&self, buf: &mut Vec<u8>) {
buf.push(*self as u8);
}
}
impl DeserializeValue for Acl {
fn deserialize_value(bytes: &mut std::slice::Iter<'_, u8>) -> Option<Self> {
impl DeserializeFrom for Acl {
fn deserialize_from(bytes: &mut std::slice::Iter<'_, u8>) -> Option<Self> {
match *bytes.next()? {
0 => Some(Acl::Read),
1 => Some(Acl::Modify),

View file

@ -25,7 +25,7 @@ use std::{borrow::Borrow, io::Write};
use store::{
rand::{self, Rng},
write::now,
write::{now, DeserializeFrom, SerializeInto},
BlobKind,
};
use utils::codec::{
@ -33,10 +33,7 @@ use utils::codec::{
leb128::{Leb128Iterator, Leb128Writer},
};
use crate::{
object::{DeserializeValue, SerializeValue},
parser::{base32::JsonBase32Reader, json::Parser, JsonObjectParser},
};
use crate::parser::{base32::JsonBase32Reader, json::Parser, JsonObjectParser};
use super::date::UTCDate;
@ -155,7 +152,7 @@ impl BlobId {
.into()
}
fn serialize_into(&self, writer: &mut (impl Write + Leb128Writer)) {
fn serialize_as(&self, writer: &mut (impl Write + Leb128Writer)) {
let kind = self
.section
.as_ref()
@ -264,19 +261,19 @@ impl std::fmt::Display for BlobId {
#[allow(clippy::unused_io_amount)]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut writer = Base32Writer::with_capacity(std::mem::size_of::<BlobId>() * 2);
self.serialize_into(&mut writer);
self.serialize_as(&mut writer);
f.write_str(&writer.finalize())
}
}
impl SerializeValue for BlobId {
fn serialize_value(self, buf: &mut Vec<u8>) {
self.serialize_into(buf)
impl SerializeInto for BlobId {
fn serialize_into(&self, buf: &mut Vec<u8>) {
self.serialize_as(buf)
}
}
impl DeserializeValue for BlobId {
fn deserialize_value(bytes: &mut std::slice::Iter<'_, u8>) -> Option<Self> {
impl DeserializeFrom for BlobId {
fn deserialize_from(bytes: &mut std::slice::Iter<'_, u8>) -> Option<Self> {
BlobId::from_iter(bytes)
}
}

View file

@ -1,3 +1,5 @@
use std::fmt::{self, Display, Formatter};
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[repr(u8)]
pub enum Collection {
@ -32,3 +34,18 @@ impl From<Collection> for u8 {
v as u8
}
}
// Renders each collection as its camelCase JMAP object-type name.
impl Display for Collection {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
Collection::Principal => write!(f, "principal"),
Collection::PushSubscription => write!(f, "pushSubscription"),
Collection::Email => write!(f, "email"),
Collection::Mailbox => write!(f, "mailbox"),
Collection::Thread => write!(f, "thread"),
Collection::Identity => write!(f, "identity"),
Collection::EmailSubmission => write!(f, "emailSubmission"),
Collection::SieveScript => write!(f, "sieveScript"),
}
}
}

View file

@ -23,7 +23,6 @@
use std::ops::Deref;
use store::{write::IntoBitmap, Serialize, BM_TAG, TAG_ID};
use utils::codec::base32_custom::{BASE32_ALPHABET, BASE32_INVERSE};
use crate::{
@ -252,12 +251,6 @@ impl From<Id> for String {
}
}
impl IntoBitmap for Id {
fn into_bitmap(self) -> (Vec<u8>, u8) {
(self.serialize(), BM_TAG | TAG_ID)
}
}
impl serde::Serialize for Id {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where

View file

@ -1,26 +1,26 @@
use std::fmt::Display;
use store::{write::IntoBitmap, BM_TAG, TAG_STATIC, TAG_TEXT};
use store::{
write::{BitmapFamily, DeserializeFrom, Operation, SerializeInto, ToBitmaps},
Serialize, BM_TAG, TAG_STATIC, TAG_TEXT,
};
use utils::codec::leb128::{Leb128Iterator, Leb128Vec};
use crate::{
object::{DeserializeValue, SerializeValue},
parser::{json::Parser, JsonObjectParser},
};
use crate::parser::{json::Parser, JsonObjectParser};
pub const SEEN: u8 = 0;
pub const DRAFT: u8 = 1;
pub const FLAGGED: u8 = 2;
pub const ANSWERED: u8 = 3;
pub const RECENT: u8 = 4;
pub const IMPORTANT: u8 = 5;
pub const PHISHING: u8 = 6;
pub const JUNK: u8 = 7;
pub const NOTJUNK: u8 = 8;
pub const DELETED: u8 = 9;
pub const FORWARDED: u8 = 10;
pub const MDN_SENT: u8 = 11;
pub const OTHER: u8 = 12;
pub const SEEN: usize = 0;
pub const DRAFT: usize = 1;
pub const FLAGGED: usize = 2;
pub const ANSWERED: usize = 3;
pub const RECENT: usize = 4;
pub const IMPORTANT: usize = 5;
pub const PHISHING: usize = 6;
pub const JUNK: usize = 7;
pub const NOTJUNK: usize = 8;
pub const DELETED: usize = 9;
pub const FORWARDED: usize = 10;
pub const MDN_SENT: usize = 11;
pub const OTHER: usize = 12;
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize)]
#[serde(untagged)]
@ -122,63 +122,84 @@ impl Display for Keyword {
}
}
impl IntoBitmap for &Keyword {
fn into_bitmap(self) -> (Vec<u8>, u8) {
match self {
Keyword::Seen => (vec![SEEN], BM_TAG | TAG_STATIC),
Keyword::Draft => (vec![DRAFT], BM_TAG | TAG_STATIC),
Keyword::Flagged => (vec![FLAGGED], BM_TAG | TAG_STATIC),
Keyword::Answered => (vec![ANSWERED], BM_TAG | TAG_STATIC),
Keyword::Recent => (vec![RECENT], BM_TAG | TAG_STATIC),
Keyword::Important => (vec![IMPORTANT], BM_TAG | TAG_STATIC),
Keyword::Phishing => (vec![PHISHING], BM_TAG | TAG_STATIC),
Keyword::Junk => (vec![JUNK], BM_TAG | TAG_STATIC),
Keyword::NotJunk => (vec![NOTJUNK], BM_TAG | TAG_STATIC),
Keyword::Deleted => (vec![DELETED], BM_TAG | TAG_STATIC),
Keyword::Forwarded => (vec![FORWARDED], BM_TAG | TAG_STATIC),
Keyword::MdnSent => (vec![MDN_SENT], BM_TAG | TAG_STATIC),
Keyword::Other(string) => (string.as_bytes().to_vec(), BM_TAG | TAG_TEXT),
// Selects the tag-bitmap family for a keyword: free-form (`Other`)
// keywords are stored as text tags; all well-known keyword variants as
// static (numeric) tags.
impl BitmapFamily for Keyword {
fn family(&self) -> u8 {
if matches!(self, Keyword::Other(_)) {
BM_TAG | TAG_TEXT
} else {
BM_TAG | TAG_STATIC
}
}
}
impl IntoBitmap for Keyword {
fn into_bitmap(self) -> (Vec<u8>, u8) {
// Emits a single bitmap operation for this keyword, keyed by its
// serialized byte form and tagged with the family from `family()`.
impl ToBitmaps for Keyword {
fn to_bitmaps(&self, ops: &mut Vec<store::write::Operation>, field: u8, set: bool) {
ops.push(Operation::Bitmap {
family: self.family(),
field,
key: self.serialize(),
set,
});
}
}
impl Serialize for Keyword {
fn serialize(self) -> Vec<u8> {
match self {
Keyword::Seen => (vec![SEEN], BM_TAG | TAG_STATIC),
Keyword::Draft => (vec![DRAFT], BM_TAG | TAG_STATIC),
Keyword::Flagged => (vec![FLAGGED], BM_TAG | TAG_STATIC),
Keyword::Answered => (vec![ANSWERED], BM_TAG | TAG_STATIC),
Keyword::Recent => (vec![RECENT], BM_TAG | TAG_STATIC),
Keyword::Important => (vec![IMPORTANT], BM_TAG | TAG_STATIC),
Keyword::Phishing => (vec![PHISHING], BM_TAG | TAG_STATIC),
Keyword::Junk => (vec![JUNK], BM_TAG | TAG_STATIC),
Keyword::NotJunk => (vec![NOTJUNK], BM_TAG | TAG_STATIC),
Keyword::Deleted => (vec![DELETED], BM_TAG | TAG_STATIC),
Keyword::Forwarded => (vec![FORWARDED], BM_TAG | TAG_STATIC),
Keyword::MdnSent => (vec![MDN_SENT], BM_TAG | TAG_STATIC),
Keyword::Other(string) => (string.into_bytes(), BM_TAG | TAG_TEXT),
Keyword::Seen => vec![SEEN as u8],
Keyword::Draft => vec![DRAFT as u8],
Keyword::Flagged => vec![FLAGGED as u8],
Keyword::Answered => vec![ANSWERED as u8],
Keyword::Recent => vec![RECENT as u8],
Keyword::Important => vec![IMPORTANT as u8],
Keyword::Phishing => vec![PHISHING as u8],
Keyword::Junk => vec![JUNK as u8],
Keyword::NotJunk => vec![NOTJUNK as u8],
Keyword::Deleted => vec![DELETED as u8],
Keyword::Forwarded => vec![FORWARDED as u8],
Keyword::MdnSent => vec![MDN_SENT as u8],
Keyword::Other(string) => string.into_bytes(),
}
}
}
impl SerializeValue for Keyword {
fn serialize_value(self, buf: &mut Vec<u8>) {
impl Serialize for &Keyword {
fn serialize(self) -> Vec<u8> {
match self {
Keyword::Seen => buf.push(SEEN),
Keyword::Draft => buf.push(DRAFT),
Keyword::Flagged => buf.push(FLAGGED),
Keyword::Answered => buf.push(ANSWERED),
Keyword::Recent => buf.push(RECENT),
Keyword::Important => buf.push(IMPORTANT),
Keyword::Phishing => buf.push(PHISHING),
Keyword::Junk => buf.push(JUNK),
Keyword::NotJunk => buf.push(NOTJUNK),
Keyword::Deleted => buf.push(DELETED),
Keyword::Forwarded => buf.push(FORWARDED),
Keyword::MdnSent => buf.push(MDN_SENT),
Keyword::Seen => vec![SEEN as u8],
Keyword::Draft => vec![DRAFT as u8],
Keyword::Flagged => vec![FLAGGED as u8],
Keyword::Answered => vec![ANSWERED as u8],
Keyword::Recent => vec![RECENT as u8],
Keyword::Important => vec![IMPORTANT as u8],
Keyword::Phishing => vec![PHISHING as u8],
Keyword::Junk => vec![JUNK as u8],
Keyword::NotJunk => vec![NOTJUNK as u8],
Keyword::Deleted => vec![DELETED as u8],
Keyword::Forwarded => vec![FORWARDED as u8],
Keyword::MdnSent => vec![MDN_SENT as u8],
Keyword::Other(string) => string.as_bytes().to_vec(),
}
}
}
impl SerializeInto for Keyword {
fn serialize_into(&self, buf: &mut Vec<u8>) {
match self {
Keyword::Seen => buf.push(SEEN as u8),
Keyword::Draft => buf.push(DRAFT as u8),
Keyword::Flagged => buf.push(FLAGGED as u8),
Keyword::Answered => buf.push(ANSWERED as u8),
Keyword::Recent => buf.push(RECENT as u8),
Keyword::Important => buf.push(IMPORTANT as u8),
Keyword::Phishing => buf.push(PHISHING as u8),
Keyword::Junk => buf.push(JUNK as u8),
Keyword::NotJunk => buf.push(NOTJUNK as u8),
Keyword::Deleted => buf.push(DELETED as u8),
Keyword::Forwarded => buf.push(FORWARDED as u8),
Keyword::MdnSent => buf.push(MDN_SENT as u8),
Keyword::Other(string) => {
buf.push_leb128(OTHER as usize + string.len());
buf.push_leb128(OTHER + string.len());
if !string.is_empty() {
buf.extend_from_slice(string.as_bytes())
}
@ -187,9 +208,9 @@ impl SerializeValue for Keyword {
}
}
impl DeserializeValue for Keyword {
fn deserialize_value(bytes: &mut std::slice::Iter<'_, u8>) -> Option<Self> {
match *bytes.next()? {
impl DeserializeFrom for Keyword {
fn deserialize_from(bytes: &mut std::slice::Iter<'_, u8>) -> Option<Self> {
match bytes.next_leb128::<usize>()? {
SEEN => Some(Keyword::Seen),
DRAFT => Some(Keyword::Draft),
FLAGGED => Some(Keyword::Flagged),
@ -202,8 +223,8 @@ impl DeserializeValue for Keyword {
DELETED => Some(Keyword::Deleted),
FORWARDED => Some(Keyword::Forwarded),
MDN_SENT => Some(Keyword::MdnSent),
_ => {
let len = bytes.next_leb128::<usize>()? - OTHER as usize;
other => {
let len = other - OTHER;
let mut keyword = Vec::with_capacity(len);
for _ in 0..len {
keyword.push(*bytes.next()?);

View file

@ -2,15 +2,13 @@ use std::fmt::{Display, Formatter};
use mail_parser::RfcHeader;
use serde::Serialize;
use store::write::{DeserializeFrom, SerializeInto};
use crate::{
object::{DeserializeValue, SerializeValue},
parser::{json::Parser, Error, JsonObjectParser},
};
use crate::parser::{json::Parser, Error, JsonObjectParser};
use super::{acl::Acl, id::Id, keyword::Keyword, value::Value};
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize)]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub enum Property {
Acl,
Aliases,
@ -801,14 +799,14 @@ impl IntoProperty for String {
}
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize)]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct HeaderProperty {
pub form: HeaderForm,
pub header: String,
pub all: bool,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize)]
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum HeaderForm {
Raw,
Text,
@ -845,8 +843,8 @@ impl Display for HeaderForm {
}
}
impl From<Property> for u8 {
fn from(value: Property) -> Self {
impl From<&Property> for u8 {
fn from(value: &Property) -> Self {
match value {
Property::IsActive => 0,
Property::IsEnabled => 1,
@ -950,6 +948,12 @@ impl From<Property> for u8 {
}
}
impl From<Property> for u8 {
fn from(value: Property) -> Self {
(&value).into()
}
}
impl From<RfcHeader> for Property {
fn from(value: RfcHeader) -> Self {
match value {
@ -969,14 +973,14 @@ impl From<RfcHeader> for Property {
}
}
impl SerializeValue for Property {
fn serialize_value(self, buf: &mut Vec<u8>) {
impl SerializeInto for Property {
fn serialize_into(&self, buf: &mut Vec<u8>) {
buf.push(self.into());
}
}
impl DeserializeValue for Property {
fn deserialize_value(bytes: &mut std::slice::Iter<'_, u8>) -> Option<Self> {
impl DeserializeFrom for Property {
fn deserialize_from(bytes: &mut std::slice::Iter<'_, u8>) -> Option<Self> {
match *bytes.next()? {
0 => Some(Property::IsActive),
1 => Some(Property::IsEnabled),
@ -1084,3 +1088,18 @@ impl DeserializeValue for Property {
}
}
}
// serde serialization: a property serializes as the string produced by
// `self.to_string()` (its JMAP property name).
impl Serialize for Property {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
// Lets APIs generic over `AsRef<Property>` accept a `Property` directly.
impl AsRef<Property> for Property {
fn as_ref(&self) -> &Property {
self
}
}

View file

@ -1,11 +1,9 @@
use std::fmt::Display;
use serde::Serialize;
use store::write::{DeserializeFrom, SerializeInto};
use crate::{
object::{DeserializeValue, SerializeValue},
parser::{json::Parser, JsonObjectParser},
};
use crate::parser::{json::Parser, JsonObjectParser};
#[derive(Debug, Eq, PartialEq, Hash, Clone, Copy, Serialize)]
#[repr(u8)]
@ -72,14 +70,14 @@ impl Display for TypeState {
}
}
impl SerializeValue for TypeState {
fn serialize_value(self, buf: &mut Vec<u8>) {
buf.push(self as u8);
impl SerializeInto for TypeState {
fn serialize_into(&self, buf: &mut Vec<u8>) {
buf.push(*self as u8);
}
}
impl DeserializeValue for TypeState {
fn deserialize_value(bytes: &mut std::slice::Iter<'_, u8>) -> Option<Self> {
impl DeserializeFrom for TypeState {
fn deserialize_from(bytes: &mut std::slice::Iter<'_, u8>) -> Option<Self> {
match *bytes.next()? {
0 => Some(TypeState::Email),
1 => Some(TypeState::EmailDelivery),

View file

@ -22,6 +22,7 @@ use super::{
};
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize)]
#[serde(untagged)]
pub enum Value {
Text(String),
UnsignedInt(u64),
@ -175,7 +176,6 @@ impl<T: JsonObjectParser + Display + Eq> JsonObjectParser for SetValueMap<T> {
let mut values = Vec::new();
match parser.next_token::<Ignore>()? {
Token::DictStart => {
parser.next_token::<Ignore>()?.assert(Token::DictStart)?;
while {
let value = parser.next_dict_key::<T>()?;
if bool::parse(parser)? {

View file

@ -2,6 +2,7 @@
name = "jmap"
version = "0.1.0"
edition = "2021"
resolver = "2"
[dependencies]
store = { path = "../store" }
@ -10,6 +11,7 @@ utils = { path = "../utils" }
mail-parser = { git = "https://github.com/stalwartlabs/mail-parser", features = ["full_encoding", "serde_support", "ludicrous_mode"] }
mail-builder = { git = "https://github.com/stalwartlabs/mail-builder", features = ["ludicrous_mode"] }
mail-send = { git = "https://github.com/stalwartlabs/mail-send" }
sieve-rs = { git = "https://github.com/stalwartlabs/sieve" }
serde = { version = "1.0", features = ["derive"]}
serde_json = "1.0"
hyper = { version = "1.0.0-rc.3", features = ["server", "http1", "http2"] }

View file

@ -0,0 +1,57 @@
use store::fts::Language;
use super::session::BaseCapabilities;
// Builds the JMAP server configuration from the `jmap.*` keys of the
// global settings. Each limit falls back to a hard-coded default when its
// key is absent; `property()?` propagates a `String` error when a value
// is present but unparsable.
impl crate::Config {
pub fn new(settings: &utils::config::Config) -> Result<Self, String> {
let mut config = Self {
// Full-text-search language: defaults to English for a missing or
// unrecognized ISO-639 code.
default_language: Language::from_iso_639(
settings.value("jmap.fts.default-language").unwrap_or("en"),
)
.unwrap_or(Language::English),
query_max_results: settings
.property("jmap.protocol.query.max-results")?
.unwrap_or(5000),
request_max_size: settings
.property("jmap.protocol.request.max-size")?
.unwrap_or(10000000),
request_max_calls: settings
.property("jmap.protocol.request.max-calls")?
.unwrap_or(16),
request_max_concurrent: settings
.property("jmap.protocol.request.max-concurrent")?
.unwrap_or(4),
request_max_concurrent_total: settings
.property("jmap.protocol.request.max-concurrent-total")?
.unwrap_or(4),
get_max_objects: settings
.property("jmap.protocol.get.max-objects")?
.unwrap_or(500),
set_max_objects: settings
.property("jmap.protocol.set.max-objects")?
.unwrap_or(500),
upload_max_size: settings
.property("jmap.protocol.upload.max-size")?
.unwrap_or(50000000),
upload_max_concurrent: settings
.property("jmap.protocol.upload.max-concurrent")?
.unwrap_or(4),
mailbox_max_depth: settings.property("jmap.mailbox.max-depth")?.unwrap_or(10),
mailbox_name_max_len: settings
.property("jmap.mailbox.max-name-length")?
.unwrap_or(255),
mail_attachments_max_size: settings
.property("jmap.email.max-attachment-size")?
.unwrap_or(50000000),
sieve_max_script_name: settings
.property("jmap.sieve.max-name-length")?
.unwrap_or(512),
sieve_max_scripts: settings
.property("jmap.protocol.max-scripts")?
.unwrap_or(256),
capabilities: BaseCapabilities::default(),
};
// NOTE(review): "add_capabilites" is misspelled ("capabilities") —
// consider renaming at its definition site.
config.add_capabilites(settings);
Ok(config)
}
}

View file

@ -17,17 +17,20 @@ use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
};
use utils::listener::{SessionData, SessionManager};
use utils::listener::{ServerInstance, SessionData, SessionManager};
use crate::{
blob::{DownloadResponse, UploadResponse},
JMAP,
};
use super::session::Session;
impl JMAP {
pub async fn parse_request(
&self,
req: &mut hyper::Request<hyper::body::Incoming>,
instance: &ServerInstance,
) -> hyper::Response<BoxBody<Bytes, hyper::Error>> {
let mut path = req.uri().path().split('/');
path.next();
@ -65,7 +68,15 @@ impl JMAP {
}
.into_http_response(),
Ok(None) => RequestError::not_found().into_http_response(),
Err(err) => RequestError::internal_server_error().into_http_response(),
Err(err) => {
tracing::error!(event = "error",
context = "blob_store",
account_id = account_id.document_id(),
blob_id = ?blob_id,
error = ?err,
"Failed to download blob");
RequestError::internal_server_error().into_http_response()
}
};
}
}
@ -103,7 +114,10 @@ impl JMAP {
},
".well-known" => match (path.next().unwrap_or(""), req.method()) {
("jmap", &Method::GET) => {
todo!()
return match self.handle_session_resource(instance).await {
Ok(session) => session.into_http_response(),
Err(err) => err.into_http_response(),
};
}
("oauth-authorization-server", &Method::GET) => {
todo!()
@ -155,7 +169,6 @@ impl SessionManager for super::SessionManager {
span,
in_flight: session.in_flight,
instance: session.instance,
shutdown_rx: session.shutdown_rx,
},
)
.await;
@ -175,6 +188,10 @@ impl SessionManager for super::SessionManager {
}
});
}
fn max_concurrent(&self) -> u64 {
self.inner.config.request_max_concurrent_total
}
}
async fn handle_request<T: AsyncRead + AsyncWrite + Unpin + 'static>(
@ -182,6 +199,8 @@ async fn handle_request<T: AsyncRead + AsyncWrite + Unpin + 'static>(
session: SessionData<T>,
) {
let span = session.span;
let _in_flight = session.in_flight;
if let Err(http_err) = http1::Builder::new()
.keep_alive(true)
.serve_connection(
@ -189,9 +208,10 @@ async fn handle_request<T: AsyncRead + AsyncWrite + Unpin + 'static>(
service_fn(|mut req: hyper::Request<body::Incoming>| {
let jmap = jmap.clone();
let span = span.clone();
let instance = session.instance.clone();
async move {
let response = jmap.parse_request(&mut req).await;
let response = jmap.parse_request(&mut req, &instance).await;
tracing::debug!(
parent: &span,
@ -208,7 +228,8 @@ async fn handle_request<T: AsyncRead + AsyncWrite + Unpin + 'static>(
{
tracing::debug!(
parent: &span,
event = "http-error",
event = "error",
context = "http",
reason = %http_err,
);
}
@ -237,6 +258,24 @@ trait ToHttpResponse {
impl ToHttpResponse for Response {
fn into_http_response(self) -> hyper::Response<BoxBody<Bytes, hyper::Error>> {
let delete = "";
println!("-> {}", serde_json::to_string_pretty(&self).unwrap());
hyper::Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json; charset=utf-8")
.body(
Full::new(Bytes::from(serde_json::to_string(&self).unwrap()))
.map_err(|never| match never {})
.boxed(),
)
.unwrap()
}
}
impl ToHttpResponse for Session {
fn into_http_response(self) -> hyper::Response<BoxBody<Bytes, hyper::Error>> {
let delete = "";
println!("-> {}", serde_json::to_string_pretty(&self).unwrap());
hyper::Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json; charset=utf-8")
@ -276,6 +315,9 @@ impl ToHttpResponse for DownloadResponse {
impl ToHttpResponse for UploadResponse {
fn into_http_response(self) -> hyper::Response<BoxBody<Bytes, hyper::Error>> {
let delete = "";
println!("-> {}", serde_json::to_string_pretty(&self).unwrap());
hyper::Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json; charset=utf-8")
@ -290,6 +332,9 @@ impl ToHttpResponse for UploadResponse {
impl ToHttpResponse for RequestError {
fn into_http_response(self) -> hyper::Response<BoxBody<Bytes, hyper::Error>> {
let delete = "";
println!("-> {}", serde_json::to_string_pretty(&self).unwrap());
hyper::Response::builder()
.status(self.status)
.header(header::CONTENT_TYPE, "application/json; charset=utf-8")

View file

@ -2,6 +2,7 @@ use std::sync::Arc;
use crate::JMAP;
pub mod config;
pub mod http;
pub mod request;
pub mod session;
@ -10,3 +11,11 @@ pub mod session;
pub struct SessionManager {
pub inner: Arc<JMAP>,
}
impl From<JMAP> for SessionManager {
fn from(jmap: JMAP) -> Self {
SessionManager {
inner: Arc::new(jmap),
}
}
}

View file

@ -1,7 +1,7 @@
use jmap_proto::{
error::request::RequestError,
method::{get, query},
request::{Request, RequestMethod},
request::{method::MethodName, Request, RequestMethod},
response::{Response, ResponseMethod},
};
@ -22,7 +22,7 @@ impl JMAP {
for mut call in request.method_calls {
// Resolve result and id references
if let Err(method_error) = response.resolve_references(&mut call.method) {
response.push_response(call.id, method_error);
response.push_response(call.id, MethodName::error(), method_error);
continue;
}
@ -62,7 +62,16 @@ impl JMAP {
RequestMethod::Echo(call) => call.into(),
RequestMethod::Error(error) => error.into(),
};
response.push_response(call.id, method_response);
response.push_response(
call.id,
if !matches!(method_response, ResponseMethod::Error(_)) {
call.name
} else {
MethodName::error()
},
method_response,
);
}
Ok(response)

View file

@ -1,8 +1,11 @@
use jmap_proto::{request::capability::Capability, response::serialize_hex, types::id::Id};
use jmap_proto::{
error::request::RequestError, request::capability::Capability,
response::serialize::serialize_hex, types::id::Id,
};
use store::ahash::AHashSet;
use utils::map::vec_map::VecMap;
use utils::{listener::ServerInstance, map::vec_map::VecMap, UnwrapFailure};
use crate::Config;
use crate::JMAP;
#[derive(Debug, Clone, serde::Serialize)]
pub struct Session {
@ -126,28 +129,42 @@ struct SubmissionCapabilities {
#[derive(Debug, Clone, serde::Serialize)]
struct VacationResponseCapabilities {}
struct BaseCapabilities {
#[derive(Default)]
pub struct BaseCapabilities {
capabilities: VecMap<Capability, Capabilities>,
}
impl BaseCapabilities {
pub fn new(config: &crate::Config, raw_config: &Config) -> Self {
Self {
capabilities: VecMap::from_iter([
(
Capability::Core,
Capabilities::Core(CoreCapabilities::new(config)),
),
(
Capability::Mail,
Capabilities::Mail(MailCapabilities::new(config)),
),
(
Capability::Sieve,
Capabilities::Sieve(SieveCapabilities::new(config, raw_config)),
),
]),
}
impl JMAP {
    /// Builds the JMAP Session object served at `/.well-known/jmap`.
    ///
    /// NOTE(review): the state (0) and primary account (id 1,
    /// "jdoe@example.org") below are hard-coded placeholder values —
    /// presumably test scaffolding to be replaced by real authentication
    /// and account lookup; confirm before this ships.
    pub async fn handle_session_resource(
        &self,
        instance: &ServerInstance,
    ) -> Result<Session, RequestError> {
        // Capabilities were precomputed into `self.config.capabilities`
        // when the configuration was parsed (see `add_capabilites`).
        let mut session = Session::new(&instance.data, &self.config.capabilities);
        session.set_state(0);
        session.set_primary_account(
            1u64.into(),
            "jdoe@example.org".to_string(),
            "John Doe".to_string(),
            None,
        );
        Ok(session)
    }
}
impl crate::Config {
    /// Registers the server's advertised JMAP capabilities (core, mail and
    /// sieve) on this configuration, building each from the parsed settings.
    ///
    /// NOTE(review): the method name misspells "capabilities"; it is called
    /// as-is elsewhere (e.g. from `Config::new`), so renaming it would need
    /// a coordinated change across all call sites.
    pub fn add_capabilites(&mut self, settings: &utils::config::Config) {
        self.capabilities.capabilities.append(
            Capability::Core,
            Capabilities::Core(CoreCapabilities::new(self)),
        );
        self.capabilities.capabilities.append(
            Capability::Mail,
            Capabilities::Mail(MailCapabilities::new(self)),
        );
        self.capabilities.capabilities.append(
            Capability::Sieve,
            Capabilities::Sieve(SieveCapabilities::new(self, settings)),
        );
    }
}
@ -156,7 +173,7 @@ impl Session {
let mut capabilities = base_capabilities.capabilities.clone();
capabilities.append(
Capability::WebSocket,
Capabilities::WebSocket(WebSocketCapabilities::new(&base_url)),
Capabilities::WebSocket(WebSocketCapabilities::new(base_url)),
);
Session {
@ -190,11 +207,11 @@ impl Session {
if let Some(capabilities) = capabilities {
for capability in capabilities {
self.primary_accounts.append(capability.clone(), account_id);
self.primary_accounts.append(*capability, account_id);
}
} else {
for capability in self.capabilities.keys() {
self.primary_accounts.append(capability.clone(), account_id);
self.primary_accounts.append(*capability, account_id);
}
}
@ -250,7 +267,7 @@ impl Account {
if let Some(capabilities) = capabilities {
for capability in capabilities {
self.account_capabilities.append(
capability.clone(),
*capability,
core_capabilities.get(capability).unwrap().clone(),
);
}
@ -264,13 +281,13 @@ impl Account {
impl CoreCapabilities {
pub fn new(config: &crate::Config) -> Self {
CoreCapabilities {
max_size_upload: config.max_size_upload,
max_concurrent_upload: config.max_concurrent_uploads,
max_size_request: config.max_size_request,
max_concurrent_requests: config.max_concurrent_requests,
max_calls_in_request: config.max_calls_in_request,
max_objects_in_get: config.max_objects_in_get,
max_objects_in_set: config.max_objects_in_set,
max_size_upload: config.upload_max_size,
max_concurrent_upload: config.upload_max_concurrent,
max_size_request: config.request_max_size,
max_concurrent_requests: config.request_max_concurrent as usize,
max_calls_in_request: config.request_max_calls,
max_objects_in_get: config.get_max_objects,
max_objects_in_set: config.set_max_objects,
collation_algorithms: vec![
"i;ascii-numeric".to_string(),
"i;ascii-casemap".to_string(),
@ -290,25 +307,23 @@ impl WebSocketCapabilities {
}
impl SieveCapabilities {
pub fn new(config: &crate::Config, raw_config: &Config) -> Self {
pub fn new(config: &crate::Config, settings: &utils::config::Config) -> Self {
let mut notification_methods = Vec::new();
for part in settings
.get("sieve-notification-uris")
.unwrap_or_else(|| "mailto".to_string())
.split_ascii_whitespace()
{
if !part.is_empty() {
notification_methods.push(part.to_string());
}
for (_, uri) in settings.values("jmap.sieve.notification-uris") {
notification_methods.push(uri.to_string());
}
if notification_methods.is_empty() {
notification_methods.push("mailto".to_string());
}
let mut capabilities: AHashSet<Capability> =
AHashSet::from_iter(Capability::all().iter().cloned());
if let Some(disable) = settings.get("sieve-disable-capabilities") {
for item in disable.split_ascii_whitespace() {
capabilities.remove(&Capability::parse(item));
}
let mut capabilities: AHashSet<sieve::compiler::grammar::Capability> =
AHashSet::from_iter(sieve::compiler::grammar::Capability::all().iter().cloned());
for (_, capability) in settings.values("jmap.sieve.disabled-capabilities") {
capabilities.remove(&sieve::compiler::grammar::Capability::parse(capability));
}
let mut extensions = capabilities
.into_iter()
.map(|c| c.to_string())
@ -318,10 +333,14 @@ impl SieveCapabilities {
SieveCapabilities {
max_script_name: config.sieve_max_script_name,
max_script_size: settings
.parse("sieve-max-script-size")
.property("jmap.sieve.max-script-size")
.failed("Invalid configuration file")
.unwrap_or(1024 * 1024),
max_scripts: config.sieve_max_scripts,
max_redirects: settings.parse("sieve-max-redirects").unwrap_or(1),
max_redirects: settings
.property("jmap.sieve.max-redirects")
.failed("Invalid configuration file")
.unwrap_or(1),
extensions,
notification_methods: if !notification_methods.is_empty() {
notification_methods.into()

View file

@ -1,8 +1,11 @@
use jmap_proto::types::blob::BlobId;
use std::ops::Range;
use jmap_proto::{error::method::MethodError, types::blob::BlobId};
use mail_parser::{
decoders::{base64::base64_decode, quoted_printable::quoted_printable_decode},
Encoding,
};
use store::BlobKind;
use crate::JMAP;
@ -36,4 +39,22 @@ impl JMAP {
self.store.get_blob(&blob_id.kind, 0..u32::MAX).await
}
}
/// Fetches a byte `range` of the blob identified by `kind` from the store.
///
/// Storage failures are logged and surfaced to the caller as the generic
/// `MethodError::ServerPartialFail`; a missing blob yields `Ok(None)`.
pub async fn get_blob(
    &self,
    kind: &BlobKind,
    range: Range<u32>,
) -> Result<Option<Vec<u8>>, MethodError> {
    self.store.get_blob(kind, range).await.map_err(|err| {
        // Log the underlying store error before masking it.
        tracing::error!(event = "error",
                        context = "blob_store",
                        blob_id = ?kind,
                        error = ?err,
                        "Failed to retrieve blob");
        MethodError::ServerPartialFail
    })
}
}

View file

@ -16,16 +16,23 @@ impl JMAP {
) -> Result<UploadResponse, RequestError> {
let blob_id = BlobId::temporary(account_id.document_id());
self.store
.put_blob(&blob_id.kind, data)
.await
.map_err(|err| RequestError::internal_server_error())?;
Ok(UploadResponse {
account_id,
blob_id,
c_type: content_type.to_string(),
size: data.len(),
})
match self.store.put_blob(&blob_id.kind, data).await {
Ok(_) => Ok(UploadResponse {
account_id,
blob_id,
c_type: content_type.to_string(),
size: data.len(),
}),
Err(err) => {
tracing::error!(event = "error",
context = "blob_store",
account_id = account_id.document_id(),
blob_id = ?blob_id,
size = data.len(),
error = ?err,
"Failed to upload blob");
Err(RequestError::internal_server_error())
}
}
}
}

View file

@ -0,0 +1,27 @@
use jmap_proto::{
error::method::MethodError,
types::{collection::Collection, state::State},
};
use crate::JMAP;
impl JMAP {
    /// Returns the current JMAP state for `collection` in `account_id`,
    /// derived from the last change id recorded in the store.
    ///
    /// A store failure is logged and reported as `ServerPartialFail`.
    pub async fn get_state(
        &self,
        account_id: u32,
        collection: Collection,
    ) -> Result<State, MethodError> {
        self.store
            .get_last_change_id(account_id, collection)
            .await
            .map(Into::into)
            .map_err(|err| {
                tracing::error!(event = "error",
                                context = "store",
                                account_id = account_id,
                                collection = ?collection,
                                error = ?err,
                                "Failed to obtain state");
                MethodError::ServerPartialFail
            })
    }
}

View file

@ -2,10 +2,12 @@ use jmap_proto::{
error::method::MethodError,
method::get::{GetRequest, GetResponse},
object::{email::GetArguments, Object},
types::{blob::BlobId, collection::Collection, property::Property, value::Value},
types::{
blob::BlobId, collection::Collection, id::Id, keyword::Keyword, property::Property,
value::Value,
},
};
use mail_parser::Message;
use store::ValueKey;
use crate::{email::headers::HeaderToValue, JMAP};
@ -72,11 +74,7 @@ impl JMAP {
let account_id = request.account_id.document_id();
let mut response = GetResponse {
account_id: Some(request.account_id),
state: self
.store
.get_last_change_id(account_id, Collection::Email)
.await?
.into(),
state: self.get_state(account_id, Collection::Email).await?,
list: Vec::with_capacity(ids.len()),
not_found: vec![],
};
@ -106,20 +104,20 @@ impl JMAP {
for id in ids {
// Obtain the email object
let mut values = if let Some(value) = self
.store
.get_value::<Object<Value>>(ValueKey::new(
let mut values = match self
.get_property::<Object<Value>>(
account_id,
Collection::Email,
id.document_id(),
Property::BodyStructure,
))
&Property::BodyStructure,
)
.await?
{
value
} else {
response.not_found.push(id);
continue;
Some(values) => values,
None => {
response.not_found.push(id);
continue;
}
};
// Retrieve raw message if needed
@ -135,10 +133,15 @@ impl JMAP {
u32::MAX
};
if let Some(raw_message) = self.store.get_blob(&blob_id.kind, 0..offset).await? {
if let Some(raw_message) = self.get_blob(&blob_id.kind, 0..offset).await? {
raw_message
} else {
let log = "true";
tracing::warn!(event = "not-found",
account_id = account_id,
collection = ?Collection::Email,
document_id = id.document_id(),
blob_id = ?blob_id,
"Blob not found");
response.not_found.push(id);
continue;
}
@ -148,7 +151,13 @@ impl JMAP {
let message = if !raw_message.is_empty() {
let message = Message::parse(&raw_message);
if message.is_none() {
let log = "true";
tracing::warn!(
event = "parse-error",
account_id = account_id,
collection = ?Collection::Email,
document_id = id.document_id(),
blob_id = ?blob_id,
"Failed to parse stored message");
}
message
} else {
@ -160,26 +169,52 @@ impl JMAP {
for property in &properties {
match property {
Property::Id => {
email.append(Property::Id, *id);
email.append(Property::Id, Id::from(*id));
}
Property::ThreadId => {
email.append(Property::ThreadId, id.prefix_id());
email.append(Property::ThreadId, Id::from(id.prefix_id()));
}
Property::BlobId => {
email.append(Property::BlobId, blob_id.clone());
}
Property::MailboxIds | Property::Keywords => {
Property::MailboxIds => {
email.append(
property.clone(),
self.store
.get_value::<Value>(ValueKey::new(
account_id,
Collection::Email,
id.document_id(),
property.clone(),
))
.await?
.unwrap_or(Value::Null),
self.get_property::<Vec<u32>>(
account_id,
Collection::Email,
id.document_id(),
&Property::MailboxIds,
)
.await?
.map(|ids| {
let mut obj = Object::with_capacity(ids.len());
for id in ids {
obj.append(Property::_T(Id::from(id).to_string()), true);
}
Value::Object(obj)
})
.unwrap_or(Value::Null),
);
}
Property::Keywords => {
email.append(
property.clone(),
self.get_property::<Vec<Keyword>>(
account_id,
Collection::Email,
id.document_id(),
&Property::Keywords,
)
.await?
.map(|keywords| {
let mut obj = Object::with_capacity(keywords.len());
for keyword in keywords {
obj.append(Property::_T(keyword.to_string()), true);
}
Value::Object(obj)
})
.unwrap_or(Value::Null),
);
}
Property::Size

View file

@ -6,7 +6,6 @@ use jmap_proto::{
method::import::{ImportEmailRequest, ImportEmailResponse},
types::{collection::Collection, property::Property, state::State},
};
use store::BitmapKey;
use utils::map::vec_map::VecMap;
use crate::{MaybeError, JMAP};
@ -18,11 +17,7 @@ impl JMAP {
) -> Result<ImportEmailResponse, MethodError> {
// Validate state
let account_id = request.account_id.document_id();
let old_state: State = self
.store
.get_last_change_id(account_id, Collection::Email)
.await?
.into();
let old_state: State = self.get_state(account_id, Collection::Email).await?;
if let Some(if_in_state) = request.if_in_state {
if old_state != if_in_state {
return Err(MethodError::StateMismatch);
@ -31,8 +26,7 @@ impl JMAP {
let cococ = "implement ACLS";
let valid_mailbox_ids = self
.store
.get_bitmap(BitmapKey::document_ids(account_id, Collection::Mailbox))
.get_document_ids(account_id, Collection::Mailbox)
.await?
.unwrap_or_default();
@ -56,7 +50,8 @@ impl JMAP {
);
continue;
}
for mailbox_id in &mailbox_ids {
let enable = "true";
/*for mailbox_id in &mailbox_ids {
if !valid_mailbox_ids.contains(*mailbox_id) {
not_created.append(
id,
@ -66,20 +61,29 @@ impl JMAP {
);
continue 'outer;
}
}
}*/
// Fetch raw message to import
let raw_message =
if let Some(raw_message) = self.blob_download(&email.blob_id, account_id).await? {
raw_message
} else {
let raw_message = match self.blob_download(&email.blob_id, account_id).await {
Ok(Some(raw_message)) => raw_message,
Ok(None) => {
not_created.append(
id,
SetError::new(SetErrorType::BlobNotFound)
.with_description(format!("BlobId {} not found.", email.blob_id)),
);
continue;
};
}
Err(err) => {
tracing::error!(event = "error",
context = "store",
account_id = account_id,
blob_id = ?email.blob_id,
error = ?err,
"Failed to retrieve blob");
return Err(MethodError::ServerPartialFail);
}
};
// Import message
match self
@ -101,7 +105,7 @@ impl JMAP {
SetError::new(SetErrorType::InvalidEmail).with_description(reason),
);
}
Err(MaybeError::Temporary(_)) => {
Err(MaybeError::Temporary) => {
return Err(MethodError::ServerPartialFail);
}
}
@ -110,10 +114,7 @@ impl JMAP {
Ok(ImportEmailResponse {
account_id: request.account_id,
new_state: if !created.is_empty() {
self.store
.get_last_change_id(account_id, Collection::Email)
.await?
.into()
self.get_state(account_id, Collection::Email).await?
} else {
old_state.clone()
},

View file

@ -4,6 +4,7 @@ use jmap_proto::{
object::Object,
types::{
date::UTCDate,
id::Id,
keyword::Keyword,
property::{HeaderForm, Property},
value::Value,
@ -50,18 +51,10 @@ impl IndexMessage for BatchBuilder {
let mut object = Object::with_capacity(15);
// Index keywords
self.value(
Property::Keywords,
Value::from(keywords),
F_VALUE | F_BITMAP,
);
self.value(Property::Keywords, keywords, F_VALUE | F_BITMAP);
// Index mailboxIds
self.value(
Property::MailboxIds,
Value::from(mailbox_ids),
F_VALUE | F_BITMAP,
);
self.value(Property::MailboxIds, mailbox_ids, F_VALUE | F_BITMAP);
// Index size
object.append(Property::Size, message.raw_message.len());
@ -348,7 +341,7 @@ impl IndexMessage for BatchBuilder {
}
// Store properties
self.value(Property::BodyStructure, Value::from(object), F_VALUE);
self.value(Property::BodyStructure, object, F_VALUE);
// Store full text index
self.custom(fts)?;

View file

@ -92,15 +92,41 @@ impl JMAP {
let document_id = self
.store
.assign_document_id(account_id, Collection::Email)
.await?;
.await
.map_err(|err| {
tracing::error!(
event = "error",
context = "email_ingest",
error = ?err,
"Failed to assign documentId.");
MaybeError::Temporary
})?;
let change_id = self
.store
.assign_change_id(account_id, Collection::Email)
.await?;
.await
.map_err(|err| {
tracing::error!(
event = "error",
context = "email_ingest",
error = ?err,
"Failed to assign changeId.");
MaybeError::Temporary
})?;
// Store blob
let blob_id = BlobId::maildir(account_id, document_id);
self.store.put_blob(&blob_id.kind, raw_message).await?;
self.store
.put_blob(&blob_id.kind, raw_message)
.await
.map_err(|err| {
tracing::error!(
event = "error",
context = "email_ingest",
error = ?err,
"Failed to write blob.");
MaybeError::Temporary
})?;
// Build change log
let mut changes = ChangeLogBuilder::with_change_id(change_id);
@ -111,7 +137,15 @@ impl JMAP {
let thread_id = self
.store
.assign_document_id(account_id, Collection::Thread)
.await?;
.await
.map_err(|err| {
tracing::error!(
event = "error",
context = "email_ingest",
error = ?err,
"Failed to assign documentId for new thread.");
MaybeError::Temporary
})?;
changes.log_insert(Collection::Thread, thread_id);
thread_id
};
@ -123,16 +157,42 @@ impl JMAP {
// Build write batch
let mut batch = BatchBuilder::new();
batch.index_message(
message,
keywords,
mailbox_ids,
received_at.unwrap_or_else(now),
self.config.default_language,
)?;
batch
.with_account_id(account_id)
.with_collection(Collection::Email)
.create_document(document_id)
.index_message(
message,
keywords,
mailbox_ids,
received_at.unwrap_or_else(now),
self.config.default_language,
)
.map_err(|err| {
tracing::error!(
event = "error",
context = "email_ingest",
error = ?err,
"Failed to index message.");
MaybeError::Temporary
})?;
batch.value(Property::ThreadId, thread_id, F_VALUE | F_BITMAP);
batch.custom(changes)?;
self.store.write(batch.build()).await?;
batch.custom(changes).map_err(|err| {
tracing::error!(
event = "error",
context = "email_ingest",
error = ?err,
"Failed to add changelog to write batch.");
MaybeError::Temporary
})?;
self.store.write(batch.build()).await.map_err(|err| {
tracing::error!(
event = "error",
context = "email_ingest",
error = ?err,
"Failed to write message to database.");
MaybeError::Temporary
})?;
Ok(IngestedEmail {
id,
@ -162,7 +222,15 @@ impl JMAP {
let results = self
.store
.filter(account_id, Collection::Email, filters)
.await?
.await
.map_err(|err| {
tracing::error!(
event = "error",
context = "find_or_merge_thread",
error = ?err,
"Thread search failed.");
MaybeError::Temporary
})?
.results;
if results.is_empty() {
return Ok(None);
@ -184,7 +252,15 @@ impl JMAP {
})
.collect(),
)
.await?;
.await
.map_err(|err| {
tracing::error!(
event = "error",
context = "find_or_merge_thread",
error = ?err,
"Failed to obtain threadIds.");
MaybeError::Temporary
})?;
if thread_ids.len() == 1 {
return Ok(thread_ids.into_iter().next().unwrap());
}
@ -212,7 +288,15 @@ impl JMAP {
let change_id = self
.store
.assign_change_id(account_id, Collection::Thread)
.await?;
.await
.map_err(|err| {
tracing::error!(
event = "error",
context = "find_or_merge_thread",
error = ?err,
"Failed to assign changeId for thread merge.");
MaybeError::Temporary
})?;
let mut changes = ChangeLogBuilder::with_change_id(change_id);
batch
.with_account_id(account_id)
@ -241,14 +325,28 @@ impl JMAP {
)
}
}
batch.custom(changes)?;
batch.custom(changes).map_err(|err| {
tracing::error!(
event = "error",
context = "find_or_merge_thread",
error = ?err,
"Failed to add changelog to write batch.");
MaybeError::Temporary
})?;
match self.store.write(batch.build()).await {
Ok(_) => return Ok(Some(thread_id)),
Err(store::Error::AssertValueFailed) if try_count < 3 => {
try_count += 1;
}
Err(err) => return Err(err.into()),
Err(err) => {
tracing::error!(
event = "error",
context = "find_or_merge_thread",
error = ?err,
"Failed to write thread merge batch.");
return Err(MaybeError::Temporary);
}
}
}
}
@ -258,7 +356,8 @@ impl From<IngestedEmail> for Object<Value> {
fn from(email: IngestedEmail) -> Self {
Object::with_capacity(3)
.with_property(Property::Id, email.id)
.with_property(Property::ThreadId, email.id.prefix_id())
.with_property(Property::ThreadId, Id::from(email.id.prefix_id()))
.with_property(Property::BlobId, email.blob_id)
.with_property(Property::Size, email.size)
}
}

View file

@ -8,7 +8,7 @@ use store::{
fts::Language,
query::{self, sort::Pagination},
roaring::RoaringBitmap,
BitmapKey, ValueKey,
ValueKey,
};
use crate::JMAP;
@ -23,14 +23,18 @@ impl JMAP {
for cond in request.filter {
match cond {
Filter::InMailbox(mailbox) => {
filters.push(query::Filter::is_in_bitmap(Property::MailboxIds, mailbox))
}
Filter::InMailbox(mailbox) => filters.push(query::Filter::is_in_bitmap(
Property::MailboxIds,
mailbox.document_id(),
)),
Filter::InMailboxOtherThan(mailboxes) => {
filters.push(query::Filter::Not);
filters.push(query::Filter::Or);
for mailbox in mailboxes {
filters.push(query::Filter::is_in_bitmap(Property::MailboxIds, mailbox));
filters.push(query::Filter::is_in_bitmap(
Property::MailboxIds,
mailbox.document_id(),
));
}
filters.push(query::Filter::End);
filters.push(query::Filter::End);
@ -141,18 +145,31 @@ impl JMAP {
}
Filter::SentBefore(date) => filters.push(query::Filter::lt(Property::SentAt, date)),
Filter::SentAfter(date) => filters.push(query::Filter::gt(Property::SentAt, date)),
Filter::InThread(id) => {
filters.push(query::Filter::is_in_bitmap(Property::ThreadId, id))
}
Filter::InThread(id) => filters.push(query::Filter::is_in_bitmap(
Property::ThreadId,
id.document_id(),
)),
other => return Err(MethodError::UnsupportedFilter(other.to_string())),
}
}
let result_set = self
let result_set = match self
.store
.filter(account_id, Collection::Email, filters)
.await?;
.await
{
Ok(result_set) => result_set,
Err(err) => {
tracing::error!(event = "error",
context = "store",
account_id = account_id,
collection = "email",
error = ?err,
"Filter failed");
return Err(MethodError::ServerPartialFail);
}
};
let total = result_set.results.len() as usize;
let (limit_total, limit) = if let Some(limit) = request.limit {
if limit > 0 {
@ -169,11 +186,7 @@ impl JMAP {
};
let mut response = QueryResponse {
account_id: request.account_id,
query_state: self
.store
.get_last_change_id(account_id, Collection::Email)
.await?
.into(),
query_state: self.get_state(account_id, Collection::Email).await?,
can_calculate_changes: true,
position: 0,
ids: vec![],
@ -213,15 +226,14 @@ impl JMAP {
query::Comparator::field(Property::SentAt, comparator.is_ascending)
}
SortProperty::HasKeyword => query::Comparator::set(
self.store
.get_bitmap(BitmapKey::value(
account_id,
Collection::Email,
Property::Keywords,
comparator.keyword.unwrap_or(Keyword::Seen),
))
.await?
.unwrap_or_default(),
self.get_tag(
account_id,
Collection::Email,
Property::Keywords,
comparator.keyword.unwrap_or(Keyword::Seen),
)
.await?
.unwrap_or_default(),
comparator.is_ascending,
),
SortProperty::AllInThreadHaveKeyword => query::Comparator::set(
@ -252,7 +264,7 @@ impl JMAP {
}
// Sort results
let result = self
let result = match self
.store
.sort(
result_set,
@ -266,7 +278,19 @@ impl JMAP {
request.arguments.collapse_threads.unwrap_or(false),
),
)
.await?;
.await
{
Ok(result) => result,
Err(err) => {
tracing::error!(event = "error",
context = "store",
account_id = account_id,
collection = "email",
error = ?err,
"Sort failed");
return Err(MethodError::ServerPartialFail);
}
};
// Prepare response
if result.found_anchor {
@ -291,13 +315,7 @@ impl JMAP {
match_all: bool,
) -> Result<RoaringBitmap, MethodError> {
let keyword_doc_ids = self
.store
.get_bitmap(BitmapKey::value(
account_id,
Collection::Email,
Property::Keywords,
keyword,
))
.get_tag(account_id, Collection::Email, Property::Keywords, keyword)
.await?
.unwrap_or_default();
@ -309,23 +327,16 @@ impl JMAP {
continue;
}
if let Some(thread_id) = self
.store
.get_value::<u32>(ValueKey::new(
.get_property::<u32>(
account_id,
Collection::Email,
keyword_doc_id,
Property::ThreadId,
))
&Property::ThreadId,
)
.await?
{
if let Some(thread_doc_ids) = self
.store
.get_bitmap(BitmapKey::value(
account_id,
Collection::Email,
Property::ThreadId,
thread_id,
))
.get_tag(account_id, Collection::Email, Property::ThreadId, thread_id)
.await?
{
let mut thread_tag_intersection = thread_doc_ids.clone();

View file

@ -1,8 +1,17 @@
use jmap_proto::error::method::MethodError;
use store::{fts::Language, Store};
use api::session::BaseCapabilities;
use jmap_proto::{
error::method::MethodError,
types::{collection::Collection, property::Property},
};
use store::{
fts::Language, roaring::RoaringBitmap, write::BitmapFamily, BitmapKey, Deserialize, Serialize,
Store, ValueKey,
};
use utils::UnwrapFailure;
pub mod api;
pub mod blob;
pub mod changes;
pub mod email;
pub struct JMAP {
@ -13,38 +22,118 @@ pub struct JMAP {
pub struct Config {
pub default_language: Language,
pub query_max_results: usize,
pub request_max_size: usize,
pub request_max_calls: usize,
pub request_max_concurrent: u64,
pub request_max_concurrent_total: u64,
pub get_max_objects: usize,
pub set_max_objects: usize,
pub upload_max_size: usize,
pub upload_max_concurrent: usize,
pub mailbox_max_depth: usize,
pub mailbox_name_max_len: usize,
pub mail_attachments_max_size: usize,
pub sieve_max_script_name: usize,
pub sieve_max_scripts: usize,
pub capabilities: BaseCapabilities,
}
pub enum MaybeError {
Temporary(String),
Temporary,
Permanent(String),
}
impl From<store::Error> for MaybeError {
fn from(e: store::Error) -> Self {
match e {
store::Error::InternalError(msg) => {
let log = "true";
MaybeError::Temporary(format!("Database error: {msg}"))
}
store::Error::AssertValueFailed => {
MaybeError::Permanent("Assert value failed".to_string())
}
impl JMAP {
    /// Opens the backing store and parses the JMAP settings, aborting the
    /// process via `failed(..)` if either step cannot complete.
    pub async fn new(config: &utils::config::Config) -> Self {
        let store = Store::open(config).await.failed("Unable to open database");
        let jmap_config = Config::new(config).failed("Invalid configuration file");
        Self {
            store,
            config: jmap_config,
        }
    }
}
impl From<MaybeError> for MethodError {
fn from(value: MaybeError) -> Self {
match value {
MaybeError::Temporary(msg) => {
let log = "true";
MethodError::ServerPartialFail
/// Reads a single stored `property` value for `document_id` in the given
/// account/collection, deserialized as `U`.
///
/// Returns `Ok(None)` when the value is absent; store failures are logged
/// and reported as `MethodError::ServerPartialFail`.
pub async fn get_property<U>(
    &self,
    account_id: u32,
    collection: Collection,
    document_id: u32,
    property: &Property,
) -> Result<Option<U>, MethodError>
where
    U: Deserialize + 'static,
{
    let key = ValueKey::new(account_id, collection, document_id, property);
    self.store.get_value::<U>(key).await.map_err(|err| {
        // Log the store error with full addressing context before masking it.
        tracing::error!(event = "error",
                        context = "store",
                        account_id = account_id,
                        collection = ?collection,
                        document_id = document_id,
                        property = ?property,
                        error = ?err,
                        "Failed to retrieve property");
        MethodError::ServerPartialFail
    })
}
/// Loads the bitmap of all document ids in `collection` for `account_id`.
///
/// Returns `Ok(None)` when no bitmap exists; store failures are logged and
/// reported as `MethodError::ServerPartialFail`.
pub async fn get_document_ids(
    &self,
    account_id: u32,
    collection: Collection,
) -> Result<Option<RoaringBitmap>, MethodError> {
    let key = BitmapKey::document_ids(account_id, collection);
    self.store.get_bitmap(key).await.map_err(|err| {
        // Log the underlying store error before masking it.
        tracing::error!(event = "error",
                        context = "store",
                        account_id = account_id,
                        collection = ?collection,
                        error = ?err,
                        "Failed to retrieve document ids bitmap");
        MethodError::ServerPartialFail
    })
}
/// Loads the bitmap of documents tagged with `value` under `property` in
/// the given account/collection; `Ok(None)` when no such tag bitmap exists.
/// Store failures are logged and reported as `ServerPartialFail`.
pub async fn get_tag(
    &self,
    account_id: u32,
    collection: Collection,
    property: impl AsRef<Property>,
    value: impl BitmapFamily + Serialize,
) -> Result<Option<RoaringBitmap>, MethodError> {
    let property = property.as_ref();
    match self
        .store
        .get_bitmap(BitmapKey::value(account_id, collection, property, value))
        .await
    {
        Ok(value) => Ok(value),
        Err(err) => {
            tracing::error!(event = "error",
                context = "store",
                account_id = account_id,
                collection = ?collection,
                property = ?property,
                error = ?err,
                "Failed to retrieve tag bitmap");
            Err(MethodError::ServerPartialFail)
        }
        // NOTE(review): the arm below appears to be stray residue from the
        // removed `MaybeError -> MethodError` conversion (this view is a
        // diff rendering); it does not belong in this match on a store
        // Result — confirm against the actual file.
        MaybeError::Permanent(msg) => MethodError::InvalidArguments(msg),
    }
}
}

1
crates/main/src/main.rs Normal file
View file

@ -0,0 +1 @@
// Entry point for the `stalwart-mail` binary; currently an empty placeholder.
fn main() {}

View file

@ -0,0 +1,26 @@
[package]
name = "maybe-async"
version = "0.2.7"
authors = [ "Guoli Lyu <guoli-lv@hotmail.com>" ]
edition = "2018"
readme = "README.md"
license = "MIT"
description = "A procedure macro to unify SYNC and ASYNC implementation"
repository = "https://github.com/fMeow/maybe-async-rs"
documentation = "https://docs.rs/maybe-async"
keywords = [ "maybe", "async", "futures", "macros", "proc_macro" ]
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
[dependencies.syn]
version = "1.0"
features = [ "visit-mut", "full" ]
[lib]
proc-macro = true
[features]
default = [ ]
is_sync = [ ]

View file

@ -0,0 +1,619 @@
//!
//! # Maybe-Async Procedure Macro
//!
//! **Why bother writing similar code twice for blocking and async code?**
//!
//! [![Build Status](https://github.com/fMeow/maybe-async-rs/workflows/CI%20%28Linux%29/badge.svg?branch=main)](https://github.com/fMeow/maybe-async-rs/actions)
//! [![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE)
//! [![Latest Version](https://img.shields.io/crates/v/maybe-async.svg)](https://crates.io/crates/maybe-async)
//! [![maybe-async](https://docs.rs/maybe-async/badge.svg)](https://docs.rs/maybe-async)
//!
//! When implementing both sync and async versions of an API in a crate, most
//! APIs of the two versions are almost the same except for async/await keywords.
//!
//! `maybe-async` help unifying async and sync implementation by **procedural
//! macro**.
//! - Write async code with normal `async`, `await`, and let `maybe_async`
//! handles
//! those `async` and `await` when you need a blocking code.
//! - Switch between sync and async by toggling `is_sync` feature gate in
//! `Cargo.toml`.
//! - use `must_be_async` and `must_be_sync` to keep code in specified version
//! - use `impl_async` and `impl_sync` to only compile code block on specified
//! version
//! - A handy macro to unify unit test code is also provided.
//!
//! These procedural macros can be applied to the following codes:
//! - trait item declaration
//! - trait implementation
//! - function definition
//! - struct definition
//!
//! **RECOMMENDATION**: Enable **resolver ver2** in your crate, which is
//! introduced in Rust 1.51. If not, two crates in dependency with conflict
//! version (one async and another blocking) can fail compilation.
//!
//!
//! ## Motivation
//!
//! The async/await language feature alters the async world of rust.
//! Comparing with the map/and_then style, now the async code really resembles
//! sync version code.
//!
//! In many crates, the async and sync version of crates shares the same API,
//! but the minor difference that all async code must be awaited prevent the
//! unification of async and sync code. In other words, we are forced to write
//! an async and a sync implementation respectively.
//!
//! ## Macros in Detail
//!
//! `maybe-async` offers 4 set of attribute macros: `maybe_async`,
//! `sync_impl`/`async_impl`, `must_be_sync`/`must_be_async`, and `test`.
//!
//! To use `maybe-async`, we must know which block of codes is only used on
//! blocking implementation, and which on async. These two implementation should
//! share the same function signatures except for async/await keywords, and use
//! `sync_impl` and `async_impl` to mark these implementation.
//!
//! Use `maybe_async` macro on codes that share the same API on both async and
//! blocking code except for async/await keywords. And use feature gate
//! `is_sync` in `Cargo.toml` to toggle between async and blocking code.
//!
//! - `maybe_async`
//!
//! Offers a unified feature gate to provide sync and async conversion on
//! demand by feature gate `is_sync`, with **async first** policy.
//!
//! Want to keep async code? add `maybe_async` in dependencies with default
//! features, which means `maybe_async` is the same as `must_be_async`:
//!
//! ```toml
//! [dependencies]
//! maybe_async = "0.2"
//! ```
//!
//! Wanna convert async code to sync? Add `maybe_async` to dependencies with
//! an `is_sync` feature gate. In this way, `maybe_async` is the same as
//! `must_be_sync`:
//!
//! ```toml
//! [dependencies]
//! maybe_async = { version = "0.2", features = ["is_sync"] }
//! ```
//!
//! Not all async traits need futures that are `dyn Future + Send`.
//! To avoid having "Send" and "Sync" bounds placed on the async trait
//! methods, invoke the maybe_async macro as #[maybe_async(?Send)] on both
//! the trait and the impl blocks.
//!
//!
//! - `must_be_async`
//!
//! **Keep async**. Add `async_trait` attribute macro for trait declaration
//! or implementation to bring async fn support in traits.
//!
//! To avoid having "Send" and "Sync" bounds placed on the async trait
//! methods, invoke the maybe_async macro as #[must_be_async(?Send)].
//!
//! - `must_be_sync`
//!
//! **Convert to sync code**. Convert the async code into sync code by
//! removing all `async move`, `async` and `await` keyword
//!
//!
//! - `sync_impl`
//!
//! A sync implementation should only compile for the blocking implementation and
//! must simply disappear when we want async version.
//!
//! Although most of the API are almost the same, there definitely come to a
//! point when the async and sync version should differ greatly. For
//! example, a MongoDB client may use the same API for async and sync
//! version, but the code to actually send requests is quite different.
//!
//! Here, we can use `sync_impl` to mark a synchronous implementation, and a
//! sync implementation should disappear when we want the async version.
//!
//! - `async_impl`
//!
//! An async implementation should only compile for the async implementation and
//! must simply disappear when we want sync version.
//!
//! To avoid having "Send" and "Sync" bounds placed on the async trait
//! methods, invoke the maybe_async macro as #[async_impl(?Send)].
//!
//!
//! - `test`
//!
//! Handy macro to unify async and sync **unit and e2e test** code.
//!
//! You can specify the condition to compile to sync test code
//! and also the conditions to compile to async test code with given test
//! macro, e.x. `tokio::test`, `async_std::test` and etc. When only sync
//! condition is specified,the test code only compiles when sync condition
//! is met.
//!
//! ```rust
//! # #[maybe_async::maybe_async]
//! # async fn async_fn() -> bool {
//! # true
//! # }
//!
//! ##[maybe_async::test(
//! feature="is_sync",
//! async(
//! all(not(feature="is_sync"), feature="async_std"),
//! async_std::test
//! ),
//! async(
//! all(not(feature="is_sync"), feature="tokio"),
//! tokio::test
//! )
//! )]
//! async fn test_async_fn() {
//! let res = async_fn().await;
//! assert_eq!(res, true);
//! }
//! ```
//!
//! ## What's Under the Hook
//!
//! `maybe-async` compiles your code in a different way with the `is_sync` feature
//! gate. It removes all `await` and `async` keywords in your code under
//! `maybe_async` macro and conditionally compiles codes under `async_impl` and
//! `sync_impl`.
//!
//! Here is a detailed example of what's going on when the `is_sync` feature
//! gate set or not.
//!
//! ```rust
//! #[maybe_async::maybe_async(?Send)]
//! trait A {
//! async fn async_fn_name() -> Result<(), ()> {
//! Ok(())
//! }
//! fn sync_fn_name() -> Result<(), ()> {
//! Ok(())
//! }
//! }
//!
//! struct Foo;
//!
//! #[maybe_async::maybe_async(?Send)]
//! impl A for Foo {
//! async fn async_fn_name() -> Result<(), ()> {
//! Ok(())
//! }
//! fn sync_fn_name() -> Result<(), ()> {
//! Ok(())
//! }
//! }
//!
//! #[maybe_async::maybe_async]
//! async fn maybe_async_fn() -> Result<(), ()> {
//! let a = Foo::async_fn_name().await?;
//!
//! let b = Foo::sync_fn_name()?;
//! Ok(())
//! }
//! ```
//!
//! When `maybe-async` feature gate `is_sync` is **NOT** set, the generated code
//! is async code:
//!
//! ```rust
//! // Compiled code when `is_sync` is toggled off.
//! #[async_trait::async_trait(?Send)]
//! trait A {
//! async fn maybe_async_fn_name() -> Result<(), ()> {
//! Ok(())
//! }
//! fn sync_fn_name() -> Result<(), ()> {
//! Ok(())
//! }
//! }
//!
//! struct Foo;
//!
//! #[async_trait::async_trait(?Send)]
//! impl A for Foo {
//! async fn maybe_async_fn_name() -> Result<(), ()> {
//! Ok(())
//! }
//! fn sync_fn_name() -> Result<(), ()> {
//! Ok(())
//! }
//! }
//!
//! async fn maybe_async_fn() -> Result<(), ()> {
//! let a = Foo::maybe_async_fn_name().await?;
//! let b = Foo::sync_fn_name()?;
//! Ok(())
//! }
//! ```
//!
//! When `maybe-async` feature gate `is_sync` is set, all async keyword is
//! ignored and yields a sync version code:
//!
//! ```rust
//! // Compiled code when `is_sync` is toggled on.
//! trait A {
//! fn maybe_async_fn_name() -> Result<(), ()> {
//! Ok(())
//! }
//! fn sync_fn_name() -> Result<(), ()> {
//! Ok(())
//! }
//! }
//!
//! struct Foo;
//!
//! impl A for Foo {
//! fn maybe_async_fn_name() -> Result<(), ()> {
//! Ok(())
//! }
//! fn sync_fn_name() -> Result<(), ()> {
//! Ok(())
//! }
//! }
//!
//! fn maybe_async_fn() -> Result<(), ()> {
//! let a = Foo::maybe_async_fn_name()?;
//! let b = Foo::sync_fn_name()?;
//! Ok(())
//! }
//! ```
//!
//! ## Examples
//!
//! ### rust client for services
//!
//! When implementing rust client for any services, like awz3. The higher level
//! API of async and sync version is almost the same, such as creating or
//! deleting a bucket, retrieving an object and etc.
//!
//! The example `service_client` is a proof of concept that `maybe_async` can
//! actually free us from writing almost the same code for sync and async. We
//! can toggle between a sync AWZ3 client and async one by `is_sync` feature
//! gate when we add `maybe-async` to dependency.
//!
//!
//! # License
//! MIT
extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use syn::{
parse_macro_input, spanned::Spanned, AttributeArgs, ImplItem, Lit, Meta, NestedMeta, TraitItem,
};
use quote::quote;
use crate::{parse::Item, visit::AsyncAwaitRemoval};
mod parse;
mod visit;
/// Emits the async flavor of `input`.
///
/// Traits and trait impls are wrapped with `#[async_trait::async_trait]`
/// (with `?Send` when `send` is false, dropping the `Send` bound on the
/// generated futures); functions and statics pass through unchanged.
fn convert_async(input: &mut Item, send: bool) -> TokenStream2 {
    match input {
        Item::Impl(item) => {
            if send {
                quote!(#[async_trait::async_trait]#item)
            } else {
                quote!(#[async_trait::async_trait(?Send)]#item)
            }
        }
        Item::Trait(item) => {
            if send {
                quote!(#[async_trait::async_trait]#item)
            } else {
                quote!(#[async_trait::async_trait(?Send)]#item)
            }
        }
        Item::Fn(item) => quote!(#item),
        Item::Static(item) => quote!(#item),
    }
}
/// Emits the sync flavor of `input`: strips the `async` qualifier from every
/// function signature, then removes `async` blocks and `.await` expressions
/// from the bodies via [`AsyncAwaitRemoval`].
fn convert_sync(input: &mut Item) -> TokenStream2 {
    let tokens = match input {
        Item::Impl(item) => {
            for member in item.items.iter_mut() {
                if let ImplItem::Method(method) = member {
                    // Clearing is a no-op when the method was already sync.
                    method.sig.asyncness = None;
                }
            }
            quote!(#item)
        }
        Item::Trait(item) => {
            for member in item.items.iter_mut() {
                if let TraitItem::Method(method) = member {
                    method.sig.asyncness = None;
                }
            }
            quote!(#item)
        }
        Item::Fn(item) => {
            item.sig.asyncness = None;
            quote!(#item)
        }
        Item::Static(item) => quote!(#item),
    };
    AsyncAwaitRemoval.remove_async_await(tokens)
}
/// maybe_async attribute macro
///
/// Can be applied to trait item, trait impl, functions and struct impls.
/// Expands to the async version by default; the `is_sync` feature gate
/// switches the expansion to sync code at compile time.
#[proc_macro_attribute]
pub fn maybe_async(args: TokenStream, input: TokenStream) -> TokenStream {
    // `Send` (or no argument) keeps the `Send` bound on generated futures;
    // `?Send` opts out; anything else is reported at the call site.
    let normalized = args.to_string().replace(' ', "");
    let send = match normalized.as_str() {
        "" | "Send" => true,
        "?Send" => false,
        _ => {
            return syn::Error::new(Span::call_site(), "Only accepts `Send` or `?Send`")
                .to_compile_error()
                .into();
        }
    };
    let mut item = parse_macro_input!(input as Item);
    if cfg!(feature = "is_sync") {
        convert_sync(&mut item).into()
    } else {
        convert_async(&mut item, send).into()
    }
}
/// Forces the async expansion: wraps traits and trait impls with
/// `async_trait` regardless of the `is_sync` feature gate.
///
/// Accepts an optional `Send`/`?Send` argument controlling whether the
/// generated futures carry a `Send` bound.
#[proc_macro_attribute]
pub fn must_be_async(args: TokenStream, input: TokenStream) -> TokenStream {
    let flag = args.to_string().replace(' ', "");
    let send = if flag.is_empty() || flag == "Send" {
        true
    } else if flag == "?Send" {
        false
    } else {
        return syn::Error::new(Span::call_site(), "Only accepts `Send` or `?Send`")
            .to_compile_error()
            .into();
    };
    let mut parsed = parse_macro_input!(input as Item);
    let tokens = convert_async(&mut parsed, send);
    tokens.into()
}
/// Forces the sync expansion: strips `async`/`await` from the annotated item
/// regardless of the `is_sync` feature gate.
#[proc_macro_attribute]
pub fn must_be_sync(_args: TokenStream, input: TokenStream) -> TokenStream {
    let mut parsed = parse_macro_input!(input as Item);
    let tokens = convert_sync(&mut parsed);
    tokens.into()
}
/// Marks a sync-only implementation.
///
/// The annotated code is kept verbatim when the `is_sync` feature gate is
/// enabled, and removed from the build otherwise.
#[proc_macro_attribute]
pub fn sync_impl(_args: TokenStream, input: TokenStream) -> TokenStream {
    let body = TokenStream2::from(input);
    let output = if cfg!(feature = "is_sync") {
        quote!(#body)
    } else {
        // Sync feature disabled: the implementation disappears entirely.
        TokenStream2::new()
    };
    output.into()
}
/// mark async implementation
///
/// only compiled when `is_sync` feature gate is not set.
/// When `is_sync` is set, marked code is removed.
///
/// Accepts an optional `Send`/`?Send` argument controlling whether the
/// generated futures carry a `Send` bound.
#[proc_macro_attribute]
pub fn async_impl(args: TokenStream, input: TokenStream) -> TokenStream {
    // Renamed `_input` -> `input`: the leading underscore claimed the
    // parameter was unused, but it is parsed and expanded in the non-sync
    // branch below.
    let send = match args.to_string().replace(' ', "").as_str() {
        "" | "Send" => true,
        "?Send" => false,
        _ => {
            return syn::Error::new(Span::call_site(), "Only accepts `Send` or `?Send`")
                .to_compile_error()
                .into();
        }
    };
    let token = if cfg!(feature = "is_sync") {
        // Sync build: the async implementation is dropped entirely.
        quote!()
    } else {
        let mut item = parse_macro_input!(input as Item);
        convert_async(&mut item, send)
    };
    token.into()
}
// Converts a `NestedMeta` into a `TokenStream2`: meta items are quoted as-is,
// while string literals are re-parsed into tokens, so `test` conditions may be
// written either as bare metas or as strings. Any non-string literal produces
// a compile error at its span.
//
// NOTE(review): the `.parse().unwrap()` panics if the string literal is not
// valid Rust tokens — TODO confirm whether a spanned error would be preferable.
macro_rules! match_nested_meta_to_str_lit {
    ($t:expr) => {
        match $t {
            NestedMeta::Lit(lit) => {
                match lit {
                    Lit::Str(s) => {
                        s.value().parse::<TokenStream2>().unwrap()
                    }
                    _ => {
                        return syn::Error::new(lit.span(), "expected meta or string literal").to_compile_error().into();
                    }
                }
            }
            NestedMeta::Meta(meta) => quote!(#meta)
        }
    };
}
/// Handy macro to unify test code of sync and async code
///
/// Since the API of both sync and async code are the same,
/// with only difference that async functions must be awaited.
/// So it's tedious to write unit sync and async respectively.
///
/// This macro helps unify the sync and async unit test code.
/// Pass the condition to treat test code as sync as the first
/// argument. And specify the condition when to treat test code
/// as async and the lib to run async test, e.x. `async-std::test`,
/// `tokio::test`, or any valid attribute macro.
///
/// **ATTENTION**: do not write await inside an assert macro
///
/// - Examples
///
/// ```rust
/// #[maybe_async::maybe_async]
/// async fn async_fn() -> bool {
/// true
/// }
///
/// #[maybe_async::test(
/// // when to treat the test code as sync version
/// feature="is_sync",
/// // when to run async test
/// async(all(not(feature="is_sync"), feature="async_std"), async_std::test),
/// // you can specify multiple conditions for different async runtime
/// async(all(not(feature="is_sync"), feature="tokio"), tokio::test)
/// )]
/// async fn test_async_fn() {
/// let res = async_fn().await;
/// assert_eq!(res, true);
/// }
///
/// // Only run test in sync version
/// #[maybe_async::test(feature = "is_sync")]
/// async fn test_sync_fn() {
/// let res = async_fn().await;
/// assert_eq!(res, true);
/// }
/// ```
///
/// The above code is transcribed to the following code:
///
/// ```rust
/// # use maybe_async::{must_be_async, must_be_sync, sync_impl};
/// # #[maybe_async::maybe_async]
/// # async fn async_fn() -> bool { true }
///
/// // convert to sync version when sync condition is met, keep in async version when corresponding
/// // condition is met
/// #[cfg_attr(feature = "is_sync", must_be_sync, test)]
/// #[cfg_attr(
/// all(not(feature = "is_sync"), feature = "async_std"),
/// must_be_async,
/// async_std::test
/// )]
/// #[cfg_attr(
/// all(not(feature = "is_sync"), feature = "tokio"),
/// must_be_async,
/// tokio::test
/// )]
/// async fn test_async_fn() {
/// let res = async_fn().await;
/// assert_eq!(res, true);
/// }
///
/// // force converted to sync function, and only compile on sync condition
/// #[cfg(feature = "is_sync")]
/// #[test]
/// fn test_sync_fn() {
/// let res = async_fn();
/// assert_eq!(res, true);
/// }
/// ```
#[proc_macro_attribute]
pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
    let attr_args = parse_macro_input!(args as AttributeArgs);
    let input = TokenStream2::from(input);
    // At minimum the sync compile condition must be supplied.
    if attr_args.is_empty() {
        return syn::Error::new(
            Span::call_site(),
            "Arguments cannot be empty, at least specify the condition for sync code",
        )
        .to_compile_error()
        .into();
    }
    // The first attributes indicates sync condition
    let sync_cond = match_nested_meta_to_str_lit!(attr_args.first().unwrap());
    let mut ts = quote!(#[cfg_attr(#sync_cond, maybe_async::must_be_sync, test)]);
    // The rest attributes indicates async condition and async test macro
    // only accepts in the forms of `async(cond, test_macro)`, but `cond` and
    // `test_macro` can be either meta attributes or string literal
    let mut async_token = Vec::new();
    let mut async_conditions = Vec::new();
    for async_meta in attr_args.into_iter().skip(1) {
        match async_meta {
            NestedMeta::Meta(meta) => match meta {
                Meta::List(list) => {
                    // Each extra argument must be a list whose path is `async`.
                    let name = list.path.segments[0].ident.to_string();
                    if name.ne("async") {
                        return syn::Error::new(
                            list.path.span(),
                            format!("Unknown path: `{}`, must be `async`", name),
                        )
                        .to_compile_error()
                        .into();
                    }
                    if list.nested.len() == 2 {
                        // `async(cond, test_macro)`: emit a cfg_attr that
                        // forces the async expansion and applies the given
                        // async test attribute under `cond`.
                        let async_cond =
                            match_nested_meta_to_str_lit!(list.nested.first().unwrap());
                        let async_test = match_nested_meta_to_str_lit!(list.nested.last().unwrap());
                        let attr = quote!(
                            #[cfg_attr(#async_cond, maybe_async::must_be_async, #async_test)]
                        );
                        async_conditions.push(async_cond);
                        async_token.push(attr);
                    } else {
                        let msg = format!(
                            "Must pass two metas or string literals like `async(condition, \
                             async_test_macro)`, you passed {} metas.",
                            list.nested.len()
                        );
                        return syn::Error::new(list.span(), msg).to_compile_error().into();
                    }
                }
                _ => {
                    return syn::Error::new(
                        meta.span(),
                        "Must be list of metas like: `async(condition, async_test_macro)`",
                    )
                    .to_compile_error()
                    .into();
                }
            },
            NestedMeta::Lit(lit) => {
                return syn::Error::new(
                    lit.span(),
                    "Must be list of metas like: `async(condition, async_test_macro)`",
                )
                .to_compile_error()
                .into();
            }
        };
    }
    // Stack every async cfg_attr on top of the sync one, then the test item.
    async_token.into_iter().for_each(|t| ts.extend(t));
    ts.extend(quote!( #input ));
    // Gate the whole test behind the union of all conditions so it vanishes
    // from builds where neither the sync nor any async condition holds.
    if !async_conditions.is_empty() {
        quote! {
            #[cfg(any(#sync_cond, #(#async_conditions),*))]
            #ts
        }
    } else {
        quote! {
            #[cfg(#sync_cond)]
            #ts
        }
    }
    .into()
}

View file

@ -0,0 +1,49 @@
use proc_macro2::Span;
use syn::{
parse::{discouraged::Speculative, Parse, ParseStream, Result},
Attribute, Error, ItemFn, ItemImpl, ItemStatic, ItemTrait,
};
/// The subset of Rust items that maybe-async's attribute macros accept.
pub enum Item {
    /// A trait declaration.
    Trait(ItemTrait),
    /// A trait implementation (inherent impls are rejected during parsing).
    Impl(ItemImpl),
    /// A free function.
    Fn(ItemFn),
    /// A static item.
    Static(ItemStatic),
}
// Re-initializes `$fork` as a fresh speculative fork of `$input` and yields a
// reference to it, so each alternative parse attempt starts from the same
// stream position; the successful fork is committed with `advance_to` below.
macro_rules! fork {
    ($fork:ident = $input:ident) => {{
        $fork = $input.fork();
        &$fork
    }};
}
impl Parse for Item {
    /// Speculatively parses the input as a trait impl, trait declaration,
    /// function, or static — in that order — re-attaching the outer
    /// attributes that were consumed up front.
    fn parse(input: ParseStream) -> Result<Self> {
        let attrs = input.call(Attribute::parse_outer)?;
        let mut fork;
        let item = if let Ok(mut item) = fork!(fork = input).parse::<ItemImpl>() {
            // Inherent impls are rejected: async_trait only applies to
            // trait impls.
            if item.trait_.is_none() {
                return Err(Error::new(Span::call_site(), "expected a trait impl"));
            }
            item.attrs = attrs;
            Item::Impl(item)
        } else if let Ok(mut item) = fork!(fork = input).parse::<ItemTrait>() {
            item.attrs = attrs;
            Item::Trait(item)
        } else if let Ok(mut item) = fork!(fork = input).parse::<ItemFn>() {
            item.attrs = attrs;
            Item::Fn(item)
        } else if let Ok(mut item) = fork!(fork = input).parse::<ItemStatic>() {
            item.attrs = attrs;
            Item::Static(item)
        } else {
            // Message fixed to mention statics, which are accepted above but
            // were missing from the original error text.
            return Err(Error::new(
                Span::call_site(),
                "expected trait impl, trait, fn or static",
            ));
        };
        // Commit the successful speculative fork.
        input.advance_to(&fork);
        Ok(item)
    }
}

View file

@ -0,0 +1,188 @@
use std::iter::FromIterator;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{
parse_quote,
punctuated::Punctuated,
visit_mut::{self, visit_item_mut, visit_path_segment_mut, VisitMut},
Expr, ExprBlock, File, GenericArgument, GenericParam, Item, PathArguments, PathSegment, Type,
TypeParamBound, WherePredicate,
};
/// Visitor that substitutes a generic type parameter (e.g. the `T` in
/// `T: Future<Output = R>`) with a concrete path segment everywhere it occurs
/// in an item.
pub struct ReplaceGenericType<'a> {
    // Name of the generic parameter being replaced.
    generic_type: &'a str,
    // Path segment that replaces every use of the parameter.
    arg_type: &'a PathSegment,
}
impl<'a> ReplaceGenericType<'a> {
pub fn new(generic_type: &'a str, arg_type: &'a PathSegment) -> Self {
Self {
generic_type,
arg_type,
}
}
pub fn replace_generic_type(item: &mut Item, generic_type: &'a str, arg_type: &'a PathSegment) {
let mut s = Self::new(generic_type, arg_type);
s.visit_item_mut(item);
}
}
impl<'a> VisitMut for ReplaceGenericType<'a> {
fn visit_item_mut(&mut self, i: &mut Item) {
if let Item::Fn(item_fn) = i {
// remove generic type from generics <T, F>
let args = item_fn
.sig
.generics
.params
.iter()
.filter(|param| {
if let GenericParam::Type(type_param) = &param {
!type_param.ident.to_string().eq(self.generic_type)
} else {
true
}
})
.collect::<Vec<_>>();
item_fn.sig.generics.params =
Punctuated::from_iter(args.into_iter().cloned().collect::<Vec<_>>());
// remove generic type from where clause
if let Some(where_clause) = &mut item_fn.sig.generics.where_clause {
let new_where_clause = where_clause
.predicates
.iter()
.filter(|predicate| {
if let WherePredicate::Type(predicate_type) = predicate {
if let Type::Path(p) = &predicate_type.bounded_ty {
!p.path.segments[0].ident.to_string().eq(self.generic_type)
} else {
true
}
} else {
true
}
})
.collect::<Vec<_>>();
where_clause.predicates = Punctuated::from_iter(
new_where_clause.into_iter().cloned().collect::<Vec<_>>(),
);
};
}
visit_item_mut(self, i)
}
fn visit_path_segment_mut(&mut self, i: &mut PathSegment) {
// replace generic type with target type
if i.ident.to_string().eq(&self.generic_type) {
*i = self.arg_type.clone();
}
visit_path_segment_mut(self, i);
}
}
pub struct AsyncAwaitRemoval;
impl AsyncAwaitRemoval {
    /// Re-parses `item` as a whole file, strips all `async`/`await` syntax
    /// from it, and returns the rewritten tokens.
    pub fn remove_async_await(&mut self, item: TokenStream) -> TokenStream {
        // The tokens come from `quote!` over an already-parsed item, so a
        // re-parse failure can only be an internal maybe-async bug; `expect`
        // documents that invariant instead of a bare `unwrap`.
        let mut syntax_tree: File = syn::parse(item.into())
            .expect("maybe-async: failed to re-parse generated tokens");
        self.visit_file_mut(&mut syntax_tree);
        quote!(#syntax_tree)
    }
}
impl VisitMut for AsyncAwaitRemoval {
    /// Rewrites `expr.await` into `expr`, and `async { ... }` blocks into
    /// plain blocks (unwrapping the braces when the block holds a single
    /// statement).
    fn visit_expr_mut(&mut self, node: &mut Expr) {
        // Delegate to the default impl to visit nested expressions.
        visit_mut::visit_expr_mut(self, node);
        match node {
            Expr::Await(expr) => *node = (*expr.base).clone(),
            Expr::Async(expr) => {
                let inner = &expr.block;
                let sync_expr = if inner.stmts.len() == 1 {
                    // remove useless braces when there is only one statement
                    let stmt = &inner.stmts.get(0).unwrap();
                    // convert statement to Expr
                    parse_quote!(#stmt)
                } else {
                    Expr::Block(ExprBlock {
                        attrs: expr.attrs.clone(),
                        block: inner.clone(),
                        label: None,
                    })
                };
                *node = sync_expr;
            }
            _ => {}
        }
    }

    /// For functions generic over `Future`, replaces each future-bound
    /// generic parameter with its `Output` type, since the sync version
    /// returns the output value directly.
    fn visit_item_mut(&mut self, i: &mut Item) {
        // find generic parameter of Future and replace it with its Output type
        if let Item::Fn(item_fn) = i {
            // Pairs of (generic parameter name, its Future `Output` segment).
            let mut inputs: Vec<(String, PathSegment)> = vec![];
            // generic params: <T:Future<Output=()>, F>
            for param in &item_fn.sig.generics.params {
                // generic param: T:Future<Output=()>
                if let GenericParam::Type(type_param) = param {
                    let generic_type_name = type_param.ident.to_string();
                    // bound: Future<Output=()>
                    for bound in &type_param.bounds {
                        inputs.extend(search_trait_bound(&generic_type_name, bound));
                    }
                }
            }
            // Same scan for bounds expressed in a `where` clause.
            if let Some(where_clause) = &item_fn.sig.generics.where_clause {
                for predicate in &where_clause.predicates {
                    if let WherePredicate::Type(predicate_type) = predicate {
                        let generic_type_name = if let Type::Path(p) = &predicate_type.bounded_ty {
                            p.path.segments[0].ident.to_string()
                        } else {
                            // Non-path bounded types (e.g. `(A, B): Trait`)
                            // are not supported.
                            panic!("Please submit an issue");
                        };
                        for bound in &predicate_type.bounds {
                            inputs.extend(search_trait_bound(&generic_type_name, bound));
                        }
                    }
                }
            }
            for (generic_type_name, path_seg) in &inputs {
                ReplaceGenericType::replace_generic_type(i, generic_type_name, path_seg);
            }
        }
        visit_item_mut(self, i);
    }
}
/// Scans a trait bound for `Future<Output = Type>` and, when found, returns
/// `(generic_type_name, Output segment)` pairs to substitute.
///
/// Robustness fix: the original indexed `args.args[0]`, which panics on an
/// empty argument list and misses the `Output` binding whenever it is not the
/// first argument (e.g. `Future<'a, Output = T>`); all arguments are now
/// scanned.
fn search_trait_bound(
    generic_type_name: &str,
    bound: &TypeParamBound,
) -> Vec<(String, PathSegment)> {
    let mut inputs = vec![];
    if let TypeParamBound::Trait(trait_bound) = bound {
        // Only the final segment matters: `std::future::Future` or `Future`.
        if let Some(segment) = trait_bound.path.segments.last() {
            if segment.ident == "Future" {
                // match Future<Output=Type>
                if let PathArguments::AngleBracketed(args) = &segment.arguments {
                    for arg in &args.args {
                        // binding: Output=Type
                        if let GenericArgument::Binding(binding) = arg {
                            if let Type::Path(p) = &binding.ty {
                                if let Some(first) = p.path.segments.first() {
                                    inputs.push((generic_type_name.to_owned(), first.clone()));
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    inputs
}

View file

@ -2,13 +2,15 @@
name = "store"
version = "0.1.0"
edition = "2021"
resolver = "2"
[dependencies]
utils = { path = "../utils" }
maybe-async = { path = "../maybe-async" }
rocksdb = { version = "0.20.1", optional = true }
foundationdb = { version = "0.7.0", optional = true }
rusqlite = { version = "0.29.0", features = ["bundled"], optional = true }
tokio = { version = "1.23", features = ["sync", "fs", "io-util"], optional = true }
tokio = { version = "1.23", features = ["sync", "fs", "io-util"] }
r2d2 = { version = "0.8.10", optional = true }
futures = { version = "0.3", optional = true }
rand = "0.8.5"
@ -25,7 +27,6 @@ jieba-rs = "0.6" # Chinese stemmer
xxhash-rust = { version = "0.8.5", features = ["xxh3"] }
farmhash = "1.1.5"
siphasher = "0.3"
maybe-async = "0.2"
parking_lot = { version = "0.12.1", optional = true }
lru-cache = { version = "0.1.2", optional = true }
blake3 = "1.3.3"
@ -33,7 +34,7 @@ blake3 = "1.3.3"
[features]
default = ["sqlite"]
rocks = ["rocksdb", "rayon", "is_sync"]
sqlite = ["rusqlite", "rayon", "r2d2", "tokio", "is_sync"]
sqlite = ["rusqlite", "rayon", "r2d2", "is_sync"]
foundation = ["foundationdb", "futures", "is_async", "key_subspace"]
is_sync = ["maybe-async/is_sync", "parking_lot", "lru-cache"]
is_async = []

View file

@ -1,6 +1,7 @@
use foundationdb::Database;
use utils::config::Config;
use crate::Store;
use crate::{blob::BlobStore, Store};
impl Store {
pub async fn open(config: &Config) -> crate::Result<Self> {

View file

@ -1,9 +1,6 @@
use foundationdb::FdbError;
use crate::{
write::key::KeySerializer, AclKey, BitmapKey, BlobKey, Error, IndexKey, IndexKeyPrefix, LogKey,
Serialize, ValueKey,
};
use crate::Error;
pub mod bitmap;
pub mod main;

View file

@ -13,10 +13,11 @@ use roaring::RoaringBitmap;
use crate::{
query::Operator,
write::key::{DeserializeBigEndian, KeySerializer},
BitmapKey, Deserialize, IndexKey, IndexKeyPrefix, ReadTransaction, Serialize, Store, ValueKey,
BitmapKey, Deserialize, IndexKey, IndexKeyPrefix, Key, ReadTransaction, Serialize, Store,
ValueKey, SUBSPACE_INDEXES,
};
use super::{bitmap::DeserializeBlock, SUBSPACE_INDEXES};
use super::bitmap::DeserializeBlock;
impl ReadTransaction<'_> {
#[inline(always)]
@ -231,6 +232,43 @@ impl ReadTransaction<'_> {
Ok(())
}
/// Iterates over the key range `[begin, end]`, folding each visited
/// key/value pair into `acc` via `cb` (which returns `false` to stop early —
/// presumably; TODO confirm against the SQLite backend's implementation).
///
/// NOTE(review): not yet implemented for this backend.
pub(crate) async fn iterate<T>(
    &self,
    mut acc: T,
    begin: impl Key,
    end: impl Key,
    first: bool,
    ascending: bool,
    cb: impl Fn(&mut T, &[u8], &[u8]) -> crate::Result<bool> + Sync + Send + 'static,
) -> crate::Result<T> {
    todo!()
}
/// Returns the id of the most recent change logged for
/// `account_id`/`collection`, or `None` when no change log entries exist.
///
/// NOTE(review): not yet ported to this backend; the commented-out body
/// below is the SQLite implementation kept as a porting reference.
pub(crate) async fn get_last_change_id(
    &self,
    account_id: u32,
    collection: u8,
) -> crate::Result<Option<u64>> {
    todo!()
    /*let key = LogKey {
        account_id,
        collection,
        change_id: u64::MAX,
    }
    .serialize();
    self.conn
        .prepare_cached("SELECT k FROM l WHERE k < ? ORDER BY k DESC LIMIT 1")?
        .query_row([&key], |row| {
            let key = row.get_ref(0)?.as_bytes()?;
            key.deserialize_be_u64(key.len() - std::mem::size_of::<u64>())
                .map_err(|err| rusqlite::Error::ToSqlConversionFailure(err.into()))
        })
        .optional()
        .map_err(Into::into)*/
}
pub async fn refresh_if_old(&mut self) -> crate::Result<()> {
if self.trx_age.elapsed() > Duration::from_millis(2000) {
self.trx = self.db.create_trx()?;

View file

@ -1,4 +1,4 @@
use std::time::{Duration, Instant, SystemTime};
use std::time::{Duration, Instant};
use ahash::AHashSet;
use foundationdb::{
@ -11,17 +11,14 @@ use rand::Rng;
use crate::{
write::{
key::{DeserializeBigEndian, KeySerializer},
Batch, Operation,
now, Batch, Operation,
},
AclKey, BitmapKey, BlobKey, Deserialize, IndexKey, LogKey, Serialize, Store, ValueKey,
BM_DOCUMENT_IDS,
};
use super::{
bitmap::{next_available_index, DenseBitmap, BITS_PER_BLOCK},
AclKey, BitmapKey, Deserialize, IndexKey, LogKey, Serialize, Store, ValueKey, BM_DOCUMENT_IDS,
SUBSPACE_VALUES,
};
use super::bitmap::{next_available_index, DenseBitmap, BITS_PER_BLOCK};
#[cfg(feature = "test_mode")]
const ID_ASSIGNMENT_EXPIRY: u64 = 2; // seconds
#[cfg(not(feature = "test_mode"))]
@ -119,28 +116,6 @@ impl Store {
trx.atomic_op(&key, &and_bitmap.bitmap, MutationType::BitAnd);
};
}
Operation::Blob { key, set } => {
let key = BlobKey {
account_id,
collection,
document_id,
hash: key,
}
.serialize();
if *set {
let now_;
let value = if document_id != u32::MAX {
&[]
} else {
now_ = now().to_be_bytes();
&now_[..]
};
trx.set(&key, value);
} else {
trx.clear(&key);
}
}
Operation::Acl {
grant_account_id,
set,
@ -171,6 +146,11 @@ impl Store {
.serialize();
trx.set(&key, set);
}
Operation::AssertValue {
field,
family,
assert_value,
} => todo!(),
}
}
@ -190,8 +170,13 @@ impl Store {
}
}
pub async fn assign_document_id(&self, account_id: u32, collection: u8) -> crate::Result<u32> {
pub async fn assign_document_id(
&self,
account_id: u32,
collection: impl Into<u8>,
) -> crate::Result<u32> {
let start = Instant::now();
let collection = collection.into();
loop {
//let mut assign_source = 0;
@ -338,8 +323,13 @@ impl Store {
}
}
pub async fn assign_change_id(&self, account_id: u32, collection: u8) -> crate::Result<u64> {
pub async fn assign_change_id(
&self,
account_id: u32,
collection: impl Into<u8>,
) -> crate::Result<u64> {
let start = Instant::now();
let collection = collection.into();
let counter = KeySerializer::new(std::mem::size_of::<u32>() + 2)
.write(SUBSPACE_VALUES)
.write_leb128(account_id)
@ -371,7 +361,6 @@ impl Store {
}
}
#[cfg(test)]
pub async fn destroy(&self) {
let trx = self.db.create_trx().unwrap();
trx.clear_range(&[0u8], &[u8::MAX]);

View file

@ -4,7 +4,7 @@ use lru_cache::LruCache;
use parking_lot::Mutex;
use r2d2::Pool;
use tokio::sync::oneshot;
use utils::config::Config;
use utils::{config::Config, UnwrapFailure};
use crate::{
blob::BlobStore, Store, SUBSPACE_ACLS, SUBSPACE_BITMAPS, SUBSPACE_BLOBS, SUBSPACE_INDEXES,
@ -20,7 +20,12 @@ impl Store {
pub async fn open(config: &Config) -> crate::Result<Self> {
let db = Self {
conn_pool: Pool::new(
SqliteConnectionManager::file("/tmp/sqlite.db").with_init(|c| {
SqliteConnectionManager::file(
config
.value_require("store.db.path")
.failed("Invalid configuration file"),
)
.with_init(|c| {
c.execute_batch(concat!(
"PRAGMA journal_mode = WAL; ",
"PRAGMA synchronous = normal; ",

View file

@ -16,7 +16,7 @@ pub enum BlobStore {
impl BlobStore {
pub async fn new(config: &Config) -> crate::Result<Self> {
Ok(BlobStore::Local(
config.value_require("blob.store.path")?.into(),
config.value_require("store.blob.path")?.into(),
))
}
}

View file

@ -22,8 +22,8 @@
*/
use crate::{
write::{IntoBitmap, Operation},
BitmapKey, BM_HASH,
write::{BitmapFamily, Operation},
BitmapKey, Serialize, BM_HASH,
};
use self::{bloom::hash_token, builder::MAX_TOKEN_MASK};
@ -186,16 +186,15 @@ impl BitmapKey<Vec<u8>> {
account_id: u32,
collection: impl Into<u8>,
field: impl Into<u8>,
value: impl IntoBitmap,
value: impl BitmapFamily + Serialize,
) -> Self {
let (key, family) = value.into_bitmap();
BitmapKey {
account_id,
collection: collection.into(),
family,
family: value.family(),
field: field.into(),
block_num: 0,
key,
key: value.serialize(),
}
}
}

View file

@ -161,9 +161,8 @@ impl From<String> for Error {
}
pub const BM_DOCUMENT_IDS: u8 = 0;
pub const BM_KEYWORD: u8 = 1 << 5;
pub const BM_TAG: u8 = 1 << 6;
pub const BM_HASH: u8 = 1 << 7;
pub const BM_TAG: u8 = 1 << 5;
pub const BM_HASH: u8 = 1 << 6;
pub const HASH_EXACT: u8 = 0;
pub const HASH_STEMMED: u8 = 1 << 6;

View file

@ -7,8 +7,8 @@ use roaring::RoaringBitmap;
use crate::{
fts::{lang::LanguageDetector, Language},
write::IntoBitmap,
BitmapKey, Serialize, BM_DOCUMENT_IDS, BM_KEYWORD,
write::BitmapFamily,
BitmapKey, Serialize, BM_DOCUMENT_IDS,
};
#[derive(Debug, Clone, Copy)]
@ -119,14 +119,6 @@ impl Filter {
}
}
pub fn has_keyword(field: impl Into<u8>, value: impl Serialize) -> Self {
Filter::InBitmap {
family: BM_KEYWORD,
field: field.into(),
key: value.serialize(),
}
}
pub fn has_text(field: impl Into<u8>, text: impl Into<String>, mut language: Language) -> Self {
let mut text = text.into();
let op = if !matches!(language, Language::None) {
@ -167,12 +159,11 @@ impl Filter {
Self::has_text(field, text, Language::English)
}
pub fn is_in_bitmap(field: impl Into<u8>, value: impl IntoBitmap) -> Self {
let (key, family) = value.into_bitmap();
pub fn is_in_bitmap(field: impl Into<u8>, value: impl BitmapFamily + Serialize) -> Self {
Self::InBitmap {
family,
family: value.family(),
field: field.into(),
key,
key: value.serialize(),
}
}

View file

@ -73,8 +73,9 @@ impl ReadTransaction<'_> {
let mut sorted_results = paginate.build();
if let Some(prefix_key) = prefix_key {
for id in sorted_results.ids.iter_mut() {
if let Some(prefix_id) =
self.get_value::<u32>(prefix_key.with_document_id(*id as u32))?
if let Some(prefix_id) = self
.get_value::<u32>(prefix_key.with_document_id(*id as u32))
.await?
{
*id |= (prefix_id as u64) << 32;
}
@ -158,8 +159,9 @@ impl ReadTransaction<'_> {
for (document_id, _) in sorted_ids {
// Obtain document prefixId
let prefix_id = if let Some(prefix_key) = &paginate.prefix_key {
if let Some(prefix_id) =
self.get_value(prefix_key.with_document_id(document_id))?
if let Some(prefix_id) = self
.get_value(prefix_key.with_document_id(document_id))
.await?
{
if paginate.prefix_unique && !seen_prefixes.insert(prefix_id) {
continue;

View file

@ -1,8 +1,8 @@
use crate::BM_DOCUMENT_IDS;
use super::{
Batch, BatchBuilder, HasFlag, IntoBitmap, IntoOperations, Operation, Serialize, ToAssertValue,
ToBitmaps, F_BITMAP, F_CLEAR, F_INDEX, F_VALUE,
Batch, BatchBuilder, BitmapFamily, HasFlag, IntoOperations, Operation, Serialize,
ToAssertValue, ToBitmaps, F_BITMAP, F_CLEAR, F_INDEX, F_VALUE,
};
impl BatchBuilder {
@ -104,14 +104,13 @@ impl BatchBuilder {
pub fn bitmap(
&mut self,
field: impl Into<u8>,
value: impl IntoBitmap,
value: impl BitmapFamily + Serialize,
options: u32,
) -> &mut Self {
let (key, family) = value.into_bitmap();
self.ops.push(Operation::Bitmap {
family,
family: value.family(),
field: field.into(),
key,
key: value.serialize(),
set: !options.has_flag(F_CLEAR),
});
self

View file

@ -143,35 +143,43 @@ impl ValueKey {
impl<T: AsRef<[u8]>> Serialize for &IndexKey<T> {
fn serialize(self) -> Vec<u8> {
let key = self.key.as_ref();
#[cfg(feature = "key_subspace")]
{
KeySerializer::new(std::mem::size_of::<IndexKey<T>>() + key.len() + 1)
.write(crate::SUBSPACE_INDEXES)
#[cfg(feature = "key_subspace")]
{
KeySerializer::new(std::mem::size_of::<IndexKey<T>>() + key.len() + 1)
.write(crate::SUBSPACE_INDEXES)
}
#[cfg(not(feature = "key_subspace"))]
{
KeySerializer::new(std::mem::size_of::<IndexKey<T>>() + key.len())
}
}
#[cfg(not(feature = "key_subspace"))]
{ KeySerializer::new(std::mem::size_of::<IndexKey<T>>() + key.len()) }
.write(self.account_id)
.write(self.collection)
.write(self.field)
.write(key)
.write(self.document_id)
.finalize()
.write(self.account_id)
.write(self.collection)
.write(self.field)
.write(key)
.write(self.document_id)
.finalize()
}
}
impl Serialize for &IndexKeyPrefix {
fn serialize(self) -> Vec<u8> {
#[cfg(feature = "key_subspace")]
{
KeySerializer::new(std::mem::size_of::<IndexKeyPrefix>() + 1)
.write(crate::SUBSPACE_INDEXES)
#[cfg(feature = "key_subspace")]
{
KeySerializer::new(std::mem::size_of::<IndexKeyPrefix>() + 1)
.write(crate::SUBSPACE_INDEXES)
}
#[cfg(not(feature = "key_subspace"))]
{
KeySerializer::new(std::mem::size_of::<IndexKeyPrefix>())
}
}
#[cfg(not(feature = "key_subspace"))]
{ KeySerializer::new(std::mem::size_of::<IndexKeyPrefix>()) }
.write(self.account_id)
.write(self.collection)
.write(self.field)
.finalize()
.write(self.account_id)
.write(self.collection)
.write(self.field)
.finalize()
}
}
@ -206,51 +214,63 @@ impl Serialize for &ValueKey {
impl<T: AsRef<[u8]>> Serialize for &BitmapKey<T> {
fn serialize(self) -> Vec<u8> {
let key = self.key.as_ref();
#[cfg(feature = "key_subspace")]
{
KeySerializer::new(std::mem::size_of::<BitmapKey<T>>() + key.len() + 1)
.write(crate::SUBSPACE_BITMAPS)
#[cfg(feature = "key_subspace")]
{
KeySerializer::new(std::mem::size_of::<BitmapKey<T>>() + key.len() + 1)
.write(crate::SUBSPACE_BITMAPS)
}
#[cfg(not(feature = "key_subspace"))]
{
KeySerializer::new(std::mem::size_of::<BitmapKey<T>>() + key.len())
}
}
#[cfg(not(feature = "key_subspace"))]
{ KeySerializer::new(std::mem::size_of::<BitmapKey<T>>() + key.len()) }
.write(self.account_id)
.write(self.collection)
.write(self.family)
.write(self.field)
.write(key)
.write(self.block_num)
.finalize()
.write(self.account_id)
.write(self.collection)
.write(self.family)
.write(self.field)
.write(key)
.write(self.block_num)
.finalize()
}
}
impl Serialize for &AclKey {
fn serialize(self) -> Vec<u8> {
#[cfg(feature = "key_subspace")]
{
KeySerializer::new(std::mem::size_of::<AclKey>() + 1).write(crate::SUBSPACE_ACLS)
#[cfg(feature = "key_subspace")]
{
KeySerializer::new(std::mem::size_of::<AclKey>() + 1).write(crate::SUBSPACE_ACLS)
}
#[cfg(not(feature = "key_subspace"))]
{
KeySerializer::new(std::mem::size_of::<AclKey>())
}
}
#[cfg(not(feature = "key_subspace"))]
{ KeySerializer::new(std::mem::size_of::<AclKey>()) }
.write_leb128(self.grant_account_id)
.write(u8::MAX)
.write_leb128(self.to_account_id)
.write(self.to_collection)
.write_leb128(self.to_document_id)
.finalize()
.write_leb128(self.grant_account_id)
.write(u8::MAX)
.write_leb128(self.to_account_id)
.write(self.to_collection)
.write_leb128(self.to_document_id)
.finalize()
}
}
impl Serialize for &LogKey {
fn serialize(self) -> Vec<u8> {
#[cfg(feature = "key_subspace")]
{
KeySerializer::new(std::mem::size_of::<LogKey>() + 1).write(crate::SUBSPACE_LOGS)
#[cfg(feature = "key_subspace")]
{
KeySerializer::new(std::mem::size_of::<LogKey>() + 1).write(crate::SUBSPACE_LOGS)
}
#[cfg(not(feature = "key_subspace"))]
{
KeySerializer::new(std::mem::size_of::<LogKey>())
}
}
#[cfg(not(feature = "key_subspace"))]
{ KeySerializer::new(std::mem::size_of::<LogKey>()) }
.write(self.account_id)
.write(self.collection)
.write(self.change_id)
.finalize()
.write(self.account_id)
.write(self.collection)
.write(self.change_id)
.finalize()
}
}

View file

@ -1,8 +1,10 @@
use std::{collections::HashSet, time::SystemTime};
use std::{collections::HashSet, slice::Iter, time::SystemTime};
use utils::codec::leb128::{Leb128Iterator, Leb128Vec};
use crate::{
fts::{builder::MAX_TOKEN_LENGTH, tokenizers::space::SpaceTokenizer},
Deserialize, Serialize, BM_TAG, HASH_EXACT, TAG_ID, TAG_STATIC, TAG_TEXT,
Deserialize, Serialize, BM_TAG, HASH_EXACT, TAG_ID, TAG_STATIC,
};
pub mod batch;
@ -137,6 +139,85 @@ impl Deserialize for u32 {
}
}
pub trait SerializeInto {
fn serialize_into(&self, buf: &mut Vec<u8>);
}
pub trait DeserializeFrom: Sized {
fn deserialize_from(bytes: &mut Iter<'_, u8>) -> Option<Self>;
}
impl<T: SerializeInto> Serialize for Vec<T> {
fn serialize(self) -> Vec<u8> {
let mut bytes = Vec::with_capacity(self.len() * 4);
bytes.push_leb128(self.len());
for item in self {
item.serialize_into(&mut bytes);
}
bytes
}
}
impl SerializeInto for String {
fn serialize_into(&self, buf: &mut Vec<u8>) {
buf.push_leb128(self.len());
if !self.is_empty() {
buf.extend_from_slice(self.as_bytes());
}
}
}
impl SerializeInto for u32 {
fn serialize_into(&self, buf: &mut Vec<u8>) {
buf.push_leb128(*self);
}
}
impl SerializeInto for u64 {
fn serialize_into(&self, buf: &mut Vec<u8>) {
buf.push_leb128(*self);
}
}
impl DeserializeFrom for u32 {
fn deserialize_from(bytes: &mut Iter<'_, u8>) -> Option<Self> {
bytes.next_leb128()
}
}
impl DeserializeFrom for u64 {
fn deserialize_from(bytes: &mut Iter<'_, u8>) -> Option<Self> {
bytes.next_leb128()
}
}
impl DeserializeFrom for String {
fn deserialize_from(bytes: &mut Iter<'_, u8>) -> Option<Self> {
let len: usize = bytes.next_leb128()?;
let mut s = Vec::with_capacity(len);
for _ in 0..len {
s.push(*bytes.next()?);
}
String::from_utf8(s).ok()
}
}
impl<T: DeserializeFrom + Sync + Send> Deserialize for Vec<T> {
fn deserialize(bytes: &[u8]) -> crate::Result<Self> {
let mut bytes = bytes.iter();
let len: usize = bytes
.next_leb128()
.ok_or_else(|| crate::Error::InternalError("Failed to deserialize Vec".to_string()))?;
let mut list = Vec::with_capacity(len);
for _ in 0..len {
list.push(T::deserialize_from(&mut bytes).ok_or_else(|| {
crate::Error::InternalError("Failed to deserialize Vec".to_string())
})?);
}
Ok(list)
}
}
trait HasFlag {
fn has_flag(&self, flag: u32) -> bool;
}
@ -184,8 +265,13 @@ impl ToBitmaps for u32 {
}
impl ToBitmaps for u64 {
fn to_bitmaps(&self, _ops: &mut Vec<Operation>, _field: u8, _set: bool) {
unreachable!()
fn to_bitmaps(&self, ops: &mut Vec<Operation>, field: u8, set: bool) {
ops.push(Operation::Bitmap {
family: BM_TAG | TAG_ID,
field,
key: (*self as u32).serialize(),
set,
});
}
}
@ -195,31 +281,33 @@ impl ToBitmaps for f64 {
}
}
pub trait IntoBitmap {
fn into_bitmap(self) -> (Vec<u8>, u8);
}
impl IntoBitmap for () {
fn into_bitmap(self) -> (Vec<u8>, u8) {
(vec![], BM_TAG | TAG_STATIC)
impl<T: ToBitmaps> ToBitmaps for Vec<T> {
fn to_bitmaps(&self, ops: &mut Vec<Operation>, field: u8, set: bool) {
for item in self {
item.to_bitmaps(ops, field, set);
}
}
}
impl IntoBitmap for u32 {
fn into_bitmap(self) -> (Vec<u8>, u8) {
(self.serialize(), BM_TAG | TAG_ID)
pub trait BitmapFamily {
fn family(&self) -> u8;
}
impl BitmapFamily for () {
fn family(&self) -> u8 {
BM_TAG | TAG_STATIC
}
}
impl IntoBitmap for String {
fn into_bitmap(self) -> (Vec<u8>, u8) {
(self.serialize(), BM_TAG | TAG_TEXT)
impl BitmapFamily for u32 {
fn family(&self) -> u8 {
BM_TAG | TAG_ID
}
}
impl IntoBitmap for &str {
fn into_bitmap(self) -> (Vec<u8>, u8) {
(self.serialize(), BM_TAG | TAG_TEXT)
impl Serialize for () {
fn serialize(self) -> Vec<u8> {
Vec::with_capacity(0)
}
}

View file

@ -2,11 +2,15 @@
name = "utils"
version = "0.1.0"
edition = "2021"
resolver = "2"
[dependencies]
rustls = "0.21.0"
rustls-pemfile = "1.0"
tokio = { version = "1.23", features = ["net"] }
tokio = { version = "1.23", features = ["net", "macros"] }
tokio-rustls = { version = "0.24.0"}
serde = { version = "1.0", features = ["derive"]}
tracing = "0.1"
[target.'cfg(unix)'.dependencies]
privdrop = "0.5.3"

View file

@ -4,7 +4,7 @@ use tokio::{net::TcpListener, sync::watch};
use tokio_rustls::TlsAcceptor;
use crate::{
config::{Listener, Server, ServerProtocol, Servers},
config::{Config, Listener, Server, ServerProtocol, Servers},
failed,
listener::SessionData,
UnwrapFailure,
@ -13,12 +13,7 @@ use crate::{
use super::{limiter::ConcurrencyLimiter, ServerInstance, SessionManager};
impl Server {
pub fn spawn(
self,
manager: impl SessionManager,
max_concurrent: u64,
shutdown_rx: watch::Receiver<bool>,
) -> Result<(), String> {
pub fn spawn(self, manager: impl SessionManager, shutdown_rx: watch::Receiver<bool>) {
// Prepare instance
let instance = Arc::new(ServerInstance {
data: if matches!(self.protocol, ServerProtocol::Smtp | ServerProtocol::Lmtp) {
@ -32,11 +27,10 @@ impl Server {
hostname: self.hostname,
tls_acceptor: self.tls.map(|config| TlsAcceptor::from(Arc::new(config))),
is_tls_implicit: self.tls_implicit,
limiter: ConcurrencyLimiter::new(manager.max_concurrent()),
shutdown_rx,
});
// Start concurrency limiter
let limiter = Arc::new(ConcurrencyLimiter::new(max_concurrent));
// Spawn listeners
for listener in self.listeners {
tracing::info!(
@ -53,10 +47,9 @@ impl Server {
let listener = listener.listen();
// Spawn listener
let mut shutdown_rx = shutdown_rx.clone();
let mut shutdown_rx = instance.shutdown_rx.clone();
let manager = manager.clone();
let instance = instance.clone();
let limiter = limiter.clone();
tokio::spawn(async move {
loop {
tokio::select! {
@ -64,7 +57,7 @@ impl Server {
match stream {
Ok((stream, remote_addr)) => {
// Enforce concurrency
if let Some(in_flight) = limiter.is_allowed() {
if let Some(in_flight) = instance.limiter.is_allowed() {
let span = tracing::info_span!(
"session",
instance = instance.id,
@ -81,7 +74,6 @@ impl Server {
span,
in_flight,
instance: instance.clone(),
shutdown_rx: shutdown_rx.clone(),
});
} else {
tracing::info!(
@ -91,7 +83,7 @@ impl Server {
protocol = ?instance.protocol,
remote.ip = remote_addr.ip().to_string(),
remote.port = remote_addr.port(),
max_concurrent = max_concurrent,
max_concurrent = instance.limiter.max_concurrent,
"Too many concurrent connections."
);
};
@ -117,13 +109,16 @@ impl Server {
}
});
}
Ok(())
}
}
impl Servers {
pub fn bind(&self) {
pub fn spawn(
self,
config: &Config,
spawn: impl Fn(Server, watch::Receiver<bool>),
) -> watch::Sender<bool> {
// Bind as root
for server in &self.inner {
for listener in &server.listeners {
listener
@ -132,6 +127,26 @@ impl Servers {
.failed(&format!("Failed to bind to {}", listener.addr));
}
}
// Drop privileges
#[cfg(not(target_env = "msvc"))]
{
if let Some(run_as_user) = config.value("server.run-as.user") {
let mut pd = privdrop::PrivDrop::default().user(run_as_user);
if let Some(run_as_group) = config.value("server.run-as.group") {
pd = pd.group(run_as_group);
}
pd.apply().failed("Failed to drop privileges");
}
}
// Spawn listeners
let (shutdown_tx, shutdown_rx) = watch::channel(false);
for server in self.inner {
spawn(server, shutdown_rx.clone());
}
shutdown_tx
}
}

View file

@ -9,7 +9,7 @@ use tokio_rustls::TlsAcceptor;
use crate::config::ServerProtocol;
use self::limiter::InFlight;
use self::limiter::{ConcurrencyLimiter, InFlight};
pub mod limiter;
pub mod listen;
@ -22,6 +22,8 @@ pub struct ServerInstance {
pub data: String,
pub tls_acceptor: Option<TlsAcceptor>,
pub is_tls_implicit: bool,
pub limiter: ConcurrencyLimiter,
pub shutdown_rx: watch::Receiver<bool>,
}
pub struct SessionData<T: AsyncRead + AsyncWrite + Unpin + 'static> {
@ -31,9 +33,9 @@ pub struct SessionData<T: AsyncRead + AsyncWrite + Unpin + 'static> {
pub span: tracing::Span,
pub in_flight: InFlight,
pub instance: Arc<ServerInstance>,
pub shutdown_rx: watch::Receiver<bool>,
}
pub trait SessionManager: Sync + Send + 'static + Clone {
fn spawn(&self, session: SessionData<TcpStream>);
fn max_concurrent(&self) -> u64;
}

View file

@ -0,0 +1,109 @@
{
"mailboxIds": {
"a": true
},
"keywords": {
"tag": true
},
"size": 2651,
"receivedAt": "2054-01-02T20:53:20Z",
"from": [
{
"name": "Al Gore",
"email": "vice-president@whitehouse.gov"
}
],
"to": [
{
"name": "White House Transportation Coordinator",
"email": "transport@whitehouse.gov"
}
],
"subject": "[Fwd: Map of Argentina with Description]",
"bodyStructure": {
"partId": "0",
"headers": [
{
"name": "From",
"value": " Al Gore <vice-president@whitehouse.gov>"
},
{
"name": "To",
"value": " White House Transportation Coordinator\n <transport@whitehouse.gov>"
},
{
"name": "Subject",
"value": " [Fwd: Map of Argentina with Description]"
},
{
"name": "Content-Type",
"value": " multipart/mixed;\n boundary=\"D7F------------D7FD5A0B8AB9C65CCDBFA872\""
}
],
"type": "multipart/mixed"
},
"bodyValues": {},
"textBody": [
{
"partId": "1",
"blobId": "blob_0",
"size": 61,
"headers": [
{
"name": "Content-Type",
"value": " text/plain; charset=us-ascii"
},
{
"name": "Content-Transfer-Encoding",
"value": " 7bit"
}
],
"type": "text/plain",
"charset": "us-ascii"
}
],
"htmlBody": [
{
"partId": "1",
"blobId": "blob_0",
"size": 61,
"headers": [
{
"name": "Content-Type",
"value": " text/plain; charset=us-ascii"
},
{
"name": "Content-Transfer-Encoding",
"value": " 7bit"
}
],
"type": "text/plain",
"charset": "us-ascii"
}
],
"attachments": [
{
"partId": "2",
"blobId": "blob_2",
"size": 1979,
"headers": [
{
"name": "Content-Type",
"value": " message/rfc822"
},
{
"name": "Content-Transfer-Encoding",
"value": " 7bit"
},
{
"name": "Content-Disposition",
"value": " inline"
}
],
"type": "message/rfc822",
"disposition": "inline"
}
],
"hasAttachment": true,
"preview": "Fred,\n\nFire up Air Force One! We're going South!\n\nThanks,\nAl"
}