This commit is contained in:
mdecimus 2025-03-09 17:17:52 +01:00
parent 1f950b40de
commit 79ad95fae1
4 changed files with 300 additions and 49 deletions

View file

@ -4,21 +4,25 @@
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL
*/
use std::collections::HashMap;
use common::KV_LOCK_DAV;
use common::{Server, auth::AccessToken};
use dav_proto::schema::property::LockScope;
use dav_proto::schema::request::DeadProperty;
use dav_proto::{Depth, Timeout};
use dav_proto::schema::property::{ActiveLock, LockScope, WebDavProperty};
use dav_proto::schema::request::{DavPropertyValue, DeadProperty};
use dav_proto::schema::response::{BaseCondition, List, PropResponse};
use dav_proto::{Depth, ResourceState, Timeout};
use dav_proto::{RequestHeaders, schema::request::LockInfo};
use http_proto::HttpResponse;
use hyper::StatusCode;
use store::dispatch::lookup::KeyValue;
use store::write::{Archive, Archiver};
use store::{Serialize, blake3};
use store::write::serialize::rkyv_deserialize;
use store::write::{Archive, Archiver, now};
use store::{Serialize, U32_LEN};
use trc::AddContext;
use super::uri::{DavUriResource, UriResource};
use crate::DavError;
use crate::{DavError, DavErrorCondition};
pub(crate) trait LockRequestHandler: Sync + Send {
fn handle_lock_request(
@ -40,56 +44,156 @@ impl LockRequestHandler for Server {
let resource_hash = resource
.lock_key()
.ok_or(DavError::Code(StatusCode::CONFLICT))?;
let resource_path = resource
.resource
.ok_or(DavError::Code(StatusCode::CONFLICT))?;
if !access_token.is_member(resource.account_id.unwrap()) {
return Err(DavError::Code(StatusCode::FORBIDDEN));
}
let lock_data = if let Some(lock_data) = self
let mut lock_data = if let Some(lock_data) = self
.in_memory_store()
.key_get::<Archive>(resource_hash.as_slice())
.await
.caused_by(trc::location!())?
{
let lock_data = lock_data
.deserialize::<LockData>()
.unarchive::<LockData>()
.caused_by(trc::location!())?;
if access_token.primary_id == lock_data.owner {
Some(lock_data)
} else {
return Err(DavError::Code(StatusCode::LOCKED));
if let Some((lock_path, lock_item)) = lock_data.find_lock(resource_path) {
if !lock_item.is_lock_owner(access_token) {
return Err(DavErrorCondition::new(
StatusCode::LOCKED,
BaseCondition::LockTokenSubmitted(List(vec![
headers.format_to_base_uri(lock_path).into(),
])),
)
.into());
} else if headers.has_if()
&& !headers.eval_if(&[ResourceState {
resource: None,
etag: String::new(),
state_token: lock_item.uuid(),
}])
{
return Err(DavErrorCondition::new(
StatusCode::PRECONDITION_FAILED,
BaseCondition::LockTokenMatchesRequestUri,
)
.into());
}
} else if lock_info.is_some() {
if let Some((lock_path, lock_item)) = lock_data.can_lock(resource_path) {
if !lock_item.is_lock_owner(access_token) {
return Err(DavErrorCondition::new(
StatusCode::LOCKED,
BaseCondition::LockTokenSubmitted(List(vec![
headers.format_to_base_uri(lock_path).into(),
])),
)
.into());
} else if headers.has_if()
&& !headers.eval_if(&[ResourceState {
resource: None,
etag: String::new(),
state_token: lock_item.uuid(),
}])
{
return Err(DavErrorCondition::new(
StatusCode::PRECONDITION_FAILED,
BaseCondition::LockTokenMatchesRequestUri,
)
.into());
}
}
}
rkyv_deserialize(lock_data).caused_by(trc::location!())?
} else if lock_info.is_some() {
LockData::default()
} else {
None
return Err(DavErrorCondition::new(
StatusCode::CONFLICT,
BaseCondition::LockTokenMatchesRequestUri,
)
.into());
};
if let Some(lock_info) = lock_info {
let now = now();
let response = if let Some(lock_info) = lock_info {
let timeout = if let Timeout::Second(seconds) = headers.timeout {
std::cmp::min(seconds, self.core.dav.max_lock_timeout)
} else {
self.core.dav.max_lock_timeout
};
let lock_data = if let Some(mut lock_data) = lock_data {
lock_data.depth_infinity = matches!(headers.depth, Depth::Infinity);
lock_data.owner_dav = lock_info.owner;
lock_data.exclusive = matches!(lock_info.lock_scope, LockScope::Exclusive);
lock_data
} else {
LockData {
owner: access_token.primary_id,
depth_infinity: matches!(headers.depth, Depth::Infinity),
owner_dav: lock_info.owner,
exclusive: matches!(lock_info.lock_scope, LockScope::Exclusive),
}
let lock_item = LockItem {
owner: access_token.primary_id,
depth_infinity: matches!(headers.depth, Depth::Infinity),
owner_dav: lock_info.owner,
exclusive: matches!(lock_info.lock_scope, LockScope::Exclusive),
lock_id: store::rand::random(),
expires: now + timeout,
};
if lock_data
if lock_item
.owner_dav
.as_ref()
.is_some_and(|o| o.size() > self.core.dav.dead_property_size.unwrap_or(512))
{
return Err(DavError::Code(StatusCode::PAYLOAD_TOO_LARGE));
}
let active_lock = lock_item.to_active_lock(headers.format_to_base_uri(resource_path));
lock_data.locks.insert(resource_path.to_string(), lock_item);
HttpResponse::new(StatusCode::CREATED)
.with_lock_token(&active_lock.lock_token.as_ref().unwrap().0)
.with_xml_body(
PropResponse::new(vec![DavPropertyValue::new(
WebDavProperty::LockDiscovery,
vec![active_lock],
)])
.to_string(),
)
} else {
let lock_token = headers
.lock_token
.ok_or(DavError::Code(StatusCode::BAD_REQUEST))?;
let mut found_path = None;
for (lock_path, lock_item) in lock_data.locks.iter() {
if lock_item.uuid() == lock_token {
if lock_item.is_lock_owner(access_token) {
found_path = Some(lock_path.to_string());
break;
} else {
return Err(DavError::Code(StatusCode::FORBIDDEN));
}
}
}
if let Some(found_path) = found_path {
lock_data.locks.remove(&found_path);
HttpResponse::new(StatusCode::NO_CONTENT)
} else {
return Err(DavErrorCondition::new(
StatusCode::CONFLICT,
BaseCondition::LockTokenMatchesRequestUri,
)
.into());
}
};
// Remove expired locks
let mut max_expire = 0;
lock_data.locks.retain(|_, lock| {
if lock.expires > now {
max_expire = std::cmp::max(max_expire, lock.expires);
true
} else {
false
}
});
if !lock_data.locks.is_empty() {
self.in_memory_store()
.key_set(
KeyValue::new(
@ -98,39 +202,154 @@ impl LockRequestHandler for Server {
.serialize()
.caused_by(trc::location!())?,
)
.expires(timeout),
.expires(max_expire),
)
.await
.caused_by(trc::location!())?;
} else if lock_data.is_some() {
} else {
self.in_memory_store()
.key_delete(resource_hash.as_slice())
.key_delete(resource_hash)
.await
.caused_by(trc::location!())?;
}
todo!()
Ok(response)
}
}
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
#[derive(Debug, Default, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
struct LockData {
locks: HashMap<String, LockItem>,
}
#[derive(Debug, Clone, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
struct LockItem {
lock_id: u64,
owner: u32,
expires: u64,
depth_infinity: bool,
exclusive: bool,
owner_dav: Option<DeadProperty>,
}
impl LockItem {
pub fn to_active_lock(&self, href: String) -> ActiveLock {
ActiveLock::new(
href,
if self.exclusive {
LockScope::Exclusive
} else {
LockScope::Shared
},
)
.with_depth(if self.depth_infinity {
Depth::Infinity
} else {
Depth::Zero
})
.with_owner_opt(self.owner_dav.clone())
.with_timeout(self.expires.saturating_sub(now()))
.with_lock_token(self.uuid())
}
pub fn uuid(&self) -> String {
let lock_id_high = (self.lock_id >> 32) as u32;
let lock_id_low = self.lock_id as u32;
let expires_high = (self.expires >> 48) as u16;
let expires_low = ((self.expires >> 16) & 0xFFFF) as u16;
format!(
"urn:uuid:{:08x}-{:04x}-{:04x}-{:04x}-{:04x}{:04x}{:04x}",
lock_id_high,
lock_id_low >> 16,
lock_id_low & 0xFFFF,
self.owner >> 16,
self.owner & 0xFFFF,
expires_high,
expires_low
)
}
}
impl ArchivedLockData {
pub fn find_lock<'x, 'y>(
&'x self,
resource: &'y str,
) -> Option<(&'y str, &'x ArchivedLockItem)> {
let now = now();
let mut resource_part = resource;
loop {
if let Some(lock) = self.locks.get(resource_part).filter(|lock| {
lock.expires > now && (resource == resource_part || lock.depth_infinity)
}) {
return Some((resource_part, lock));
} else if let Some((resource_part_, _)) = resource_part.rsplit_once('/') {
resource_part = resource_part_;
} else {
return None;
}
}
}
pub fn can_lock<'x>(&'x self, resource: &'x str) -> Option<(&'x str, &'x ArchivedLockItem)> {
if let Some(lock) = self.find_lock(resource) {
Some(lock)
} else {
let now = now();
self.locks.iter().find_map(|(resource_part, lock)| {
if lock.depth_infinity
&& lock.expires > now
&& resource_part
.strip_prefix(resource)
.is_some_and(|v| v.starts_with('/'))
{
Some((resource_part.as_str(), lock))
} else {
None
}
})
}
}
}
impl ArchivedLockItem {
#[inline]
pub fn is_lock_owner(&self, access_token: &AccessToken) -> bool {
self.owner == access_token.primary_id
}
pub fn uuid(&self) -> String {
let lock_id_high = (self.lock_id >> 32) as u32;
let lock_id_low = u64::from(self.lock_id) as u32;
let expires_high = (self.expires >> 48) as u16;
let expires_low = ((self.expires >> 16) & 0xFFFF) as u16;
format!(
"urn:uuid:{:08x}-{:04x}-{:04x}-{:04x}-{:04x}{:04x}{:04x}",
lock_id_high,
lock_id_low >> 16,
lock_id_low & 0xFFFF,
self.owner >> 16,
self.owner & 0xFFFF,
expires_high,
expires_low
)
}
}
impl LockItem {
#[inline]
pub fn is_lock_owner(&self, access_token: &AccessToken) -> bool {
self.owner == access_token.primary_id
}
}
impl UriResource<Option<&str>> {
pub fn lock_key(&self) -> Option<Vec<u8>> {
let mut hasher = blake3::Hasher::new();
hasher.update(self.resource?.as_bytes());
hasher.update(self.account_id?.to_be_bytes().as_slice());
hasher.update(u8::from(self.collection).to_be_bytes().as_slice());
let hash = hasher.finalize();
let mut result = Vec::with_capacity(hash.as_bytes().len() + 1);
let mut result = Vec::with_capacity(U32_LEN + 2);
result.push(KV_LOCK_DAV);
result.extend_from_slice(hash.as_bytes());
result.extend_from_slice(self.account_id?.to_be_bytes().as_slice());
result.push(u8::from(self.collection));
Some(result)
}
}

View file

@ -49,10 +49,39 @@ pub enum DavMethod {
pub(crate) enum DavError {
Parse(dav_proto::parser::Error),
Internal(trc::Error),
Condition(Condition),
Condition(DavErrorCondition),
Code(StatusCode),
}
struct DavErrorCondition {
pub code: StatusCode,
pub condition: Condition,
}
impl From<DavErrorCondition> for DavError {
fn from(value: DavErrorCondition) -> Self {
DavError::Condition(value)
}
}
impl From<Condition> for DavErrorCondition {
fn from(value: Condition) -> Self {
DavErrorCondition {
code: StatusCode::CONFLICT,
condition: value,
}
}
}
impl DavErrorCondition {
pub fn new(code: StatusCode, condition: impl Into<Condition>) -> Self {
DavErrorCondition {
code,
condition: condition.into(),
}
}
}
impl From<DavResource> for Collection {
fn from(value: DavResource) -> Self {
match value {

View file

@ -290,15 +290,13 @@ impl DavRequestHandler for Server {
}
}
Err(DavError::Parse(err)) => HttpResponse::new(StatusCode::BAD_REQUEST),
Err(DavError::Condition(condition)) => {
HttpResponse::new(StatusCode::PRECONDITION_FAILED)
.with_xml_body(
ErrorResponse::new(condition)
.with_namespace(resource)
.to_string(),
)
.with_no_cache()
}
Err(DavError::Condition(condition)) => HttpResponse::new(condition.code)
.with_xml_body(
ErrorResponse::new(condition.condition)
.with_namespace(resource)
.to_string(),
)
.with_no_cache(),
Err(DavError::Code(code)) => HttpResponse::new(code),
}
}

View file

@ -51,6 +51,11 @@ impl HttpResponse {
self
}
pub fn with_lock_token(mut self, token_uri: &str) -> Self {
self.builder = self.builder.header("Lock-Token", format!("<{token_uri}>"));
self
}
pub fn with_header<K, V>(mut self, name: K, value: V) -> Self
where
K: TryInto<HeaderName>,