IMAP: Argument buffer parser checked push

This commit is contained in:
mdecimus 2025-09-30 12:34:59 +02:00
parent 2f6cfbb6e6
commit 4a96357b2b

View file

@ -8,8 +8,6 @@ use super::{ResponseCode, ResponseType};
use compact_str::{CompactString, format_compact};
use std::fmt::Display;
const QUOTED_ARG_MAX_LEN: usize = 4096;
#[derive(Debug, Clone)]
pub enum Error {
NeedsMoreData,
@ -65,7 +63,7 @@ pub enum State {
}
pub struct Receiver<T: CommandParser> {
buf: Vec<u8>,
buf: ArgumentBuffer,
pub request: Request<T>,
pub state: State,
pub max_request_size: usize,
@ -73,6 +71,12 @@ pub struct Receiver<T: CommandParser> {
pub start_state: State,
}
/// Upper bound, in bytes, passed to `push_checked` when accumulating a single
/// argument (both the unquoted and the quoted parser states use this limit).
const ARG_MAX_LEN: usize = 4096;
/// Byte buffer the receiver uses to accumulate the token currently being
/// parsed (tag, command name, argument, literal size or literal data).
///
/// It wraps a `Vec<u8>` so that pushes can go through a length-limited
/// `push_checked` helper instead of touching the vector directly.
struct ArgumentBuffer {
    // Raw bytes collected so far for the current token.
    buf: Vec<u8>,
}
impl<T: CommandParser> Receiver<T> {
pub fn new() -> Self {
Receiver {
@ -104,7 +108,7 @@ impl<T: CommandParser> Receiver<T> {
},
message,
);
self.buf = Vec::with_capacity(10);
self.buf = ArgumentBuffer::default();
self.state = self.start_state;
self.current_request_size = 0;
err
@ -119,8 +123,7 @@ impl<T: CommandParser> Receiver<T> {
self.max_request_size
)));
}
self.request.tokens.push(Token::Argument(self.buf.clone()));
self.buf.clear();
self.request.tokens.push(Token::Argument(self.buf.take()));
} else if in_quote {
self.request.tokens.push(Token::Nil);
}
@ -145,18 +148,18 @@ impl<T: CommandParser> Receiver<T> {
match self.state {
State::Start => {
if !ch.is_ascii_whitespace() {
self.buf.push(ch);
// SAFETY: This called just once
self.buf.push_unchecked(ch);
self.state = State::Tag;
}
}
State::Tag => match ch {
b' ' => {
if !self.buf.is_empty() {
self.request.tag = String::from_utf8(std::mem::replace(
&mut self.buf,
Vec::with_capacity(10),
))
.map_err(|_| self.error_reset("Tag is not a valid UTF-8 string."))?;
self.request.tag =
String::from_utf8(self.buf.take()).map_err(|_| {
self.error_reset("Tag is not a valid UTF-8 string.")
})?;
self.state = State::Command { is_uid: false };
}
}
@ -164,35 +167,32 @@ impl<T: CommandParser> Receiver<T> {
b'\n' => {
return Err(self.error_reset(format_compact!(
"Missing command after tag {:?}, found CRLF instead.",
std::str::from_utf8(&self.buf).unwrap_or_default()
self.buf.as_str()
)));
}
_ => {
if self.buf.len() < 128 {
self.buf.push(ch);
} else {
return Err(self.error_reset("Tag too long."));
}
self.buf.push_checked(ch, 128).map_err(|_| {
self.error_reset("Tag exceeds maximum length of 128 characters.")
})?;
}
},
State::Command { is_uid } => {
if ch.is_ascii_alphanumeric() {
if self.buf.len() < 15 {
self.buf.push(ch.to_ascii_uppercase());
} else {
return Err(self.error_reset("Command too long"));
}
self.buf
.push_checked(ch.to_ascii_uppercase(), 15)
.map_err(|_| {
self.error_reset("Command exceeds maximum length of 15 characters.")
})?;
} else if ch.is_ascii_whitespace() {
if !self.buf.is_empty() {
if !self.buf.eq_ignore_ascii_case(b"UID") {
self.request.command =
T::parse(&self.buf, is_uid).ok_or_else(|| {
let command =
String::from_utf8_lossy(&self.buf).into_owned();
self.error_reset(format_compact!(
if !self.buf.as_ref().eq_ignore_ascii_case(b"UID") {
self.request.command = T::parse(self.buf.as_ref(), is_uid)
.ok_or_else(|| {
let err = format_compact!(
"Unrecognized command '{}'.",
command
))
String::from_utf8_lossy(self.buf.as_ref())
);
self.error_reset(err)
})?;
self.buf.clear();
if ch != b'\n' {
@ -268,7 +268,9 @@ impl<T: CommandParser> Receiver<T> {
self.state = State::Argument { last_ch: ch };
}
_ => {
self.buf.push(ch);
self.buf.push_checked(ch, ARG_MAX_LEN).map_err(|_| {
self.error_reset("Argument exceeds maximum length of 4096 bytes.")
})?;
self.state = State::Argument { last_ch: ch };
}
},
@ -277,20 +279,18 @@ impl<T: CommandParser> Receiver<T> {
if !escaped {
self.push_argument(true)?;
self.state = State::Argument { last_ch: b' ' };
} else if self.buf.len() < QUOTED_ARG_MAX_LEN {
self.buf.push(ch);
self.state = State::ArgumentQuoted { escaped: false };
} else {
return Err(self.error_reset("Quoted argument too long."));
self.buf
.push_checked(ch, ARG_MAX_LEN)
.map_err(|_| self.error_reset("Quoted argument too long."))?;
self.state = State::ArgumentQuoted { escaped: false };
}
}
b'\\' => {
if escaped {
if self.buf.len() < QUOTED_ARG_MAX_LEN {
self.buf.push(ch);
} else {
return Err(self.error_reset("Quoted argument too long."));
}
self.buf
.push_checked(ch, ARG_MAX_LEN)
.map_err(|_| self.error_reset("Quoted argument too long."))?;
}
self.state = State::ArgumentQuoted { escaped: !escaped };
}
@ -298,25 +298,21 @@ impl<T: CommandParser> Receiver<T> {
return Err(self.error_reset("Unterminated quoted argument."));
}
_ => {
if self.buf.len() < QUOTED_ARG_MAX_LEN {
if escaped {
self.buf.push(b'\\');
}
self.buf.push(ch);
self.state = State::ArgumentQuoted { escaped: false };
} else {
return Err(self.error_reset("Quoted argument too long."));
if escaped {
// SAFETY: We check the size below
self.buf.push_unchecked(b'\\');
}
self.buf
.push_checked(ch, ARG_MAX_LEN)
.map_err(|_| self.error_reset("Quoted argument too long."))?;
self.state = State::ArgumentQuoted { escaped: false };
}
},
State::Literal { non_sync } => {
match ch {
b'}' => {
if !self.buf.is_empty() {
let size = std::str::from_utf8(&self.buf)
.unwrap()
.parse::<u32>()
.map_err(|_| {
let size = self.buf.as_str().parse::<u32>().map_err(|_| {
self.error_reset("Literal size is not a valid number.")
})?;
if self.current_request_size + size as usize > self.max_request_size
@ -327,7 +323,8 @@ impl<T: CommandParser> Receiver<T> {
)));
}
self.state = State::LiteralSeek { size, non_sync };
self.buf = Vec::with_capacity(size as usize);
self.buf.resize_buffer(size as usize);
self.buf.clear();
} else {
return Err(self.error_reset("Invalid empty literal."));
}
@ -341,7 +338,9 @@ impl<T: CommandParser> Receiver<T> {
}
_ if ch.is_ascii_digit() => {
if !non_sync {
self.buf.push(ch);
self.buf.push_checked(ch, 15).map_err(|_| {
self.error_reset("Literal size exceeds maximum of 15 digits.")
})?;
} else {
// Digit found after non-sync '+' flag
return Err(self.error_reset("Invalid literal."));
@ -373,7 +372,9 @@ impl<T: CommandParser> Receiver<T> {
}
}
State::LiteralData { remaining } => {
self.buf.push(ch);
// SAFETY: We checked the size before entering this state
self.buf.push_unchecked(ch);
if remaining > 1 {
self.state = State::LiteralData {
remaining: remaining - 1,
@ -390,6 +391,61 @@ impl<T: CommandParser> Receiver<T> {
}
}
impl ArgumentBuffer {
    /// Creates an empty buffer with space for 10 bytes pre-allocated.
    pub fn new() -> Self {
        ArgumentBuffer {
            buf: Vec::with_capacity(10),
        }
    }

    /// Ensures the buffer can hold at least `size` bytes in total without
    /// reallocating. Existing contents are left untouched.
    pub fn resize_buffer(&mut self, size: usize) {
        // `Vec::reserve` takes the number of *additional* elements counted from
        // the current length, not from the current capacity. Expressing the
        // request relative to `len()` guarantees `capacity() >= size` even when
        // the buffer still holds bytes (e.g. the literal-size digits). The call
        // is a no-op when the capacity is already sufficient.
        self.buf.reserve(size.saturating_sub(self.buf.len()));
    }

    /// Appends `byte` if the buffer currently holds fewer than `limit` bytes.
    ///
    /// Returns `Err(())` and leaves the buffer unchanged when the limit has
    /// already been reached.
    #[inline(always)]
    pub fn push_checked(&mut self, byte: u8, limit: usize) -> Result<(), ()> {
        if self.buf.len() < limit {
            self.buf.push(byte);
            Ok(())
        } else {
            Err(())
        }
    }

    /// Appends `byte` without applying any length limit.
    ///
    /// Despite the name this is an ordinary, memory-safe `Vec::push`; the
    /// caller is only responsible for bounding the total length by other means.
    #[inline(always)]
    pub fn push_unchecked(&mut self, byte: u8) {
        self.buf.push(byte);
    }

    /// Returns a copy of the buffered bytes and empties the buffer.
    ///
    /// The bytes are cloned rather than moved out so that the internal
    /// allocation is kept for the next token, at the cost of one copy.
    pub fn take(&mut self) -> Vec<u8> {
        let buf = self.buf.clone();
        self.buf.clear();
        buf
    }

    /// Number of bytes currently buffered.
    #[inline(always)]
    pub fn len(&self) -> usize {
        self.buf.len()
    }

    /// Returns `true` when no bytes are buffered.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.buf.is_empty()
    }

    /// Discards the buffered bytes while keeping the allocation.
    #[inline(always)]
    pub fn clear(&mut self) {
        self.buf.clear();
    }

    /// Views the buffered bytes as UTF-8 text.
    ///
    /// Returns an empty string when the contents are not valid UTF-8.
    #[inline(always)]
    pub fn as_str(&self) -> &str {
        std::str::from_utf8(&self.buf).unwrap_or_default()
    }
}
impl Token {
pub fn unwrap_string(self) -> crate::parser::Result<String> {
match self {
@ -450,6 +506,18 @@ impl Token {
}
}
impl AsRef<[u8]> for ArgumentBuffer {
    /// Borrows the bytes buffered so far as a plain slice.
    fn as_ref(&self) -> &[u8] {
        self.buf.as_slice()
    }
}
impl Default for ArgumentBuffer {
fn default() -> Self {
Self::new()
}
}
impl Display for Token {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.write_str(&String::from_utf8_lossy(self.as_bytes()))
@ -487,7 +555,7 @@ impl Error {
impl<T: CommandParser> Default for Receiver<T> {
fn default() -> Self {
Self {
buf: Vec::with_capacity(10),
buf: Default::default(),
request: Default::default(),
state: State::Start,
start_state: State::Start,