Support for milter on all SMTP stages (closes #183)

This commit is contained in:
mdecimus 2024-06-18 15:09:50 +02:00
parent f930cf59ef
commit aa4a11baf7
10 changed files with 430 additions and 81 deletions

View file

@ -3,6 +3,7 @@ use std::{
time::Duration,
};
use ahash::AHashSet;
use smtp_proto::*;
use utils::config::{utils::ParseValue, Config};
@ -30,6 +31,9 @@ pub struct SessionConfig {
pub data: Data,
pub extensions: Extensions,
pub mta_sts_policy: Option<Policy>,
pub milters: Vec<Milter>,
pub jmilters: Vec<JMilter>,
}
#[derive(Default, Debug, Clone)]
@ -114,7 +118,6 @@ pub enum AddressMapping {
pub struct Data {
pub script: IfBlock,
pub pipe_commands: Vec<Pipe>,
pub milters: Vec<Milter>,
// Limits
pub max_messages: IfBlock,
@ -154,6 +157,7 @@ pub struct Milter {
pub protocol_version: MilterVersion,
pub flags_actions: Option<u32>,
pub flags_protocol: Option<u32>,
pub run_on_stage: AHashSet<Stage>,
}
#[derive(Clone, Copy)]
@ -162,6 +166,26 @@ pub enum MilterVersion {
V6,
}
#[derive(Clone)]
pub struct JMilter {
pub enable: IfBlock,
pub url: String,
pub timeout: Duration,
pub tls_allow_invalid_certs: bool,
pub tempfail_on_error: bool,
pub run_on_stage: AHashSet<Stage>,
}
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub enum Stage {
Connect,
Ehlo,
Auth,
Mail,
Rcpt,
Data,
}
impl SessionConfig {
pub fn parse(config: &mut Config) -> Self {
let has_conn_vars = TokenMap::default().with_variables(CONNECTION_VARS);
@ -174,13 +198,20 @@ impl SessionConfig {
let mut session = SessionConfig::default();
session.rcpt.catch_all = AddressMapping::parse(config, "session.rcpt.catch-all");
session.rcpt.subaddressing = AddressMapping::parse(config, "session.rcpt.sub-addressing");
session.data.milters = config
.sub_keys("session.data.milter", "")
session.milters = config
.sub_keys("session.milter", "")
.map(|s| s.to_string())
.collect::<Vec<_>>()
.into_iter()
.filter_map(|id| parse_milter(config, &id, &has_rcpt_vars))
.collect();
session.jmilters = config
.sub_keys("session.jmilter", "")
.map(|s| s.to_string())
.collect::<Vec<_>>()
.into_iter()
.filter_map(|id| parse_jmilter(config, &id, &has_rcpt_vars))
.collect();
session.data.pipe_commands = config
.sub_keys("session.data.pipe", "")
.map(|s| s.to_string())
@ -479,19 +510,19 @@ fn parse_pipe(config: &mut Config, id: &str, token_map: &TokenMap) -> Option<Pip
fn parse_milter(config: &mut Config, id: &str, token_map: &TokenMap) -> Option<Milter> {
let hostname = config
.value_require(("session.data.milter", id, "hostname"))?
.value_require(("session.milter", id, "hostname"))?
.to_string();
let port = config.property_require(("session.data.milter", id, "port"))?;
let port = config.property_require(("session.milter", id, "port"))?;
Some(Milter {
enable: IfBlock::try_parse(config, ("session.data.milter", id, "enable"), token_map)
enable: IfBlock::try_parse(config, ("session.milter", id, "enable"), token_map)
.unwrap_or_else(|| {
IfBlock::new::<()>(format!("session.data.milter.{id}.enable"), [], "false")
IfBlock::new::<()>(format!("session.milter.{id}.enable"), [], "false")
}),
addrs: format!("{}:{}", hostname, port)
.to_socket_addrs()
.map_err(|err| {
config.new_build_error(
("session.data.milter", id, "hostname"),
("session.milter", id, "hostname"),
format!("Unable to resolve milter hostname {hostname}: {err}"),
)
})
@ -500,51 +531,105 @@ fn parse_milter(config: &mut Config, id: &str, token_map: &TokenMap) -> Option<M
hostname,
port,
timeout_connect: config
.property_or_default(("session.data.milter", id, "timeout.connect"), "30s")
.property_or_default(("session.milter", id, "timeout.connect"), "30s")
.unwrap_or_else(|| Duration::from_secs(30)),
timeout_command: config
.property_or_default(("session.data.milter", id, "timeout.command"), "30s")
.property_or_default(("session.milter", id, "timeout.command"), "30s")
.unwrap_or_else(|| Duration::from_secs(30)),
timeout_data: config
.property_or_default(("session.data.milter", id, "timeout.data"), "60s")
.property_or_default(("session.milter", id, "timeout.data"), "60s")
.unwrap_or_else(|| Duration::from_secs(60)),
tls: config
.property_or_default(("session.data.milter", id, "tls"), "false")
.property_or_default(("session.milter", id, "tls"), "false")
.unwrap_or_default(),
tls_allow_invalid_certs: config
.property_or_default(("session.data.milter", id, "allow-invalid-certs"), "false")
.property_or_default(("session.milter", id, "allow-invalid-certs"), "false")
.unwrap_or_default(),
tempfail_on_error: config
.property_or_default(
("session.data.milter", id, "options.tempfail-on-error"),
"true",
)
.property_or_default(("session.milter", id, "options.tempfail-on-error"), "true")
.unwrap_or(true),
max_frame_len: config
.property_or_default(
("session.data.milter", id, "options.max-response-size"),
("session.milter", id, "options.max-response-size"),
"52428800",
)
.unwrap_or(52428800),
protocol_version: match config
.property_or_default::<u32>(("session.data.milter", id, "options.version"), "6")
.property_or_default::<u32>(("session.milter", id, "options.version"), "6")
.unwrap_or(6)
{
6 => MilterVersion::V6,
2 => MilterVersion::V2,
v => {
config.new_parse_error(
("session.data.milter", id, "options.version"),
("session.milter", id, "options.version"),
format!("Unsupported milter protocol version {v}"),
);
MilterVersion::V6
}
},
flags_actions: config.property(("session.data.milter", id, "options.flags.actions")),
flags_protocol: config.property(("session.data.milter", id, "options.flags.protocol")),
flags_actions: config.property(("session.milter", id, "options.flags.actions")),
flags_protocol: config.property(("session.milter", id, "options.flags.protocol")),
run_on_stage: parse_stages(config, "session.milter", id),
})
}
fn parse_jmilter(config: &mut Config, id: &str, token_map: &TokenMap) -> Option<JMilter> {
Some(JMilter {
enable: IfBlock::try_parse(config, ("session.jmilter", id, "enable"), token_map)
.unwrap_or_else(|| {
IfBlock::new::<()>(format!("session.jmilter.{id}.enable"), [], "false")
}),
url: config
.value_require(("session.jmilter", id, "hostname"))?
.to_string(),
timeout: config
.property_or_default(("session.jmilter", id, "timeout"), "30s")
.unwrap_or_else(|| Duration::from_secs(30)),
tls_allow_invalid_certs: config
.property_or_default(("session.jmilter", id, "allow-invalid-certs"), "false")
.unwrap_or_default(),
tempfail_on_error: config
.property_or_default(("session.jmilter", id, "options.tempfail-on-error"), "true")
.unwrap_or(true),
run_on_stage: parse_stages(config, "session.jmilter", id),
})
}
fn parse_stages(config: &mut Config, prefix: &str, id: &str) -> AHashSet<Stage> {
let mut stages = AHashSet::default();
let mut invalid = Vec::new();
for (_, value) in config.values((prefix, id, "stages")) {
let value = value.to_ascii_lowercase();
let state = match value.as_str() {
"connect" => Stage::Connect,
"ehlo" => Stage::Ehlo,
"auth" => Stage::Auth,
"mail" => Stage::Mail,
"rcpt" => Stage::Rcpt,
"data" => Stage::Data,
_ => {
invalid.push(value);
continue;
}
};
stages.insert(state);
}
if !invalid.is_empty() {
config.new_parse_error(
(prefix, id, "stages"),
format!("Invalid stages: {}", invalid.join(", ")),
);
}
if stages.is_empty() {
stages.insert(Stage::Data);
}
stages
}
impl Default for SessionConfig {
fn default() -> Self {
Self {
@ -640,7 +725,6 @@ impl Default for SessionConfig {
"'track-replies'",
),
pipe_commands: Default::default(),
milters: Default::default(),
max_messages: IfBlock::new::<()>("session.data.limits.messages", [], "10"),
max_message_size: IfBlock::new::<()>("session.data.limits.size", [], "104857600"),
max_received_headers: IfBlock::new::<()>(
@ -716,6 +800,8 @@ impl Default for SessionConfig {
),
},
mta_sts_policy: None,
milters: Default::default(),
jmilters: Default::default(),
}
}
}

View file

@ -29,7 +29,9 @@ use std::{
};
use common::{
config::smtp::auth::VerifyStrategy, listener::SessionStream, scripts::ScriptModification,
config::smtp::{auth::VerifyStrategy, session::Stage},
listener::SessionStream,
scripts::ScriptModification,
};
use mail_auth::{
common::{headers::HeaderWriter, verify::VerifySignature},
@ -398,7 +400,7 @@ impl<T: SessionStream> Session<T> {
}
// Run Milter filters
let mut edited_message = match self.run_milters(&auth_message).await {
let mut edited_message = match self.run_milters(Stage::Data, (&auth_message).into()).await {
Ok(modifications) => {
if !modifications.is_empty() {
tracing::debug!(

View file

@ -24,7 +24,10 @@
use std::time::{Duration, SystemTime};
use crate::{core::Session, scripts::ScriptResult};
use common::{config::smtp::session::Mechanism, listener::SessionStream};
use common::{
config::smtp::session::{Mechanism, Stage},
listener::SessionStream,
};
use mail_auth::spf::verify::HasLabels;
use smtp_proto::*;
@ -102,6 +105,20 @@ impl<T: SessionStream> Session<T> {
}
}
// Milter filtering
if let Err(message) = self.run_milters(Stage::Ehlo, None).await {
tracing::info!(parent: &self.span,
context = "milter",
event = "reject",
domain = &self.data.helo_domain,
reason = std::str::from_utf8(message.as_ref()).unwrap_or_default());
self.data.mail_from = None;
self.data.helo_domain = prev_helo_domain;
self.data.spf_ehlo = None;
return self.write(message.as_ref()).await;
}
tracing::debug!(parent: &self.span,
context = "ehlo",
event = "ehlo",

View file

@ -0,0 +1,179 @@
use ahash::AHashMap;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
pub struct Request {
context: Context,
#[serde(skip_serializing_if = "Option::is_none")]
envelope: Option<Envelope>,
#[serde(skip_serializing_if = "Option::is_none")]
message: Option<Message>,
}
#[derive(Serialize, Deserialize)]
pub struct Context {
stage: Stage,
client: Client,
#[serde(skip_serializing_if = "Option::is_none")]
sasl: Option<Sasl>,
#[serde(skip_serializing_if = "Option::is_none")]
tls: Option<Tls>,
server: Server,
#[serde(skip_serializing_if = "Option::is_none")]
queue: Option<Queue>,
protocol: Protocol,
}
#[derive(Serialize, Deserialize)]
pub struct Sasl {
login: String,
method: String,
}
#[derive(Serialize, Deserialize)]
pub struct Client {
ip: String,
port: u16,
ptr: Option<String>,
helo: Option<String>,
#[serde(rename = "activeConnections")]
active_connections: u32,
}
#[derive(Serialize, Deserialize)]
pub struct Tls {
version: String,
cipher: String,
#[serde(rename = "cipherBits")]
#[serde(skip_serializing_if = "Option::is_none")]
bits: Option<u16>,
#[serde(rename = "certIssuer")]
#[serde(skip_serializing_if = "Option::is_none")]
issuer: Option<String>,
#[serde(rename = "certSubject")]
#[serde(skip_serializing_if = "Option::is_none")]
subject: Option<String>,
}
#[derive(Serialize, Deserialize)]
pub struct Server {
name: Option<String>,
port: u16,
ip: Option<String>,
}
#[derive(Serialize, Deserialize)]
pub struct Queue {
id: String,
}
#[derive(Serialize, Deserialize)]
pub struct Protocol {
version: String,
}
#[derive(Serialize, Deserialize)]
pub enum Stage {
#[serde(rename = "connect")]
Connect,
#[serde(rename = "ehlo")]
Ehlo,
#[serde(rename = "auth")]
Auth,
#[serde(rename = "mail")]
Mail,
#[serde(rename = "rcpt")]
Rcpt,
#[serde(rename = "data")]
Data,
}
#[derive(Serialize, Deserialize)]
pub struct Address {
address: String,
#[serde(skip_serializing_if = "Option::is_none")]
parameters: Option<AHashMap<String, String>>,
}
#[derive(Serialize, Deserialize)]
pub struct Envelope {
from: Address,
to: Vec<Address>,
}
#[derive(Serialize, Deserialize)]
pub struct Message {
headers: Vec<(String, String)>,
#[serde(skip_serializing_if = "Vec::is_empty")]
#[serde(rename = "serverHeaders")]
server_headers: Vec<(String, String)>,
body: String,
size: usize,
}
#[derive(Serialize, Deserialize)]
pub struct Response {
action: Action,
modifications: Vec<Modification>,
}
#[derive(Serialize, Deserialize)]
pub enum Action {
#[serde(rename = "accept")]
Accept,
#[serde(rename = "discard")]
Discard,
#[serde(rename = "reject")]
Reject,
#[serde(rename = "tempFail")]
Tempfail,
#[serde(rename = "shutdown")]
Shutdown,
#[serde(rename = "connectionFailure")]
ConnectionFailure,
#[serde(rename = "replyCode")]
ReplyCode,
}
#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Modification {
#[serde(rename = "changeFrom")]
ChangeFrom {
value: String,
#[serde(default)]
parameters: AHashMap<String, String>,
},
#[serde(rename = "addRecipient")]
AddRecipient {
value: String,
#[serde(default)]
parameters: AHashMap<String, String>,
},
#[serde(rename = "deleteRecipient")]
DeleteRecipient { value: String },
#[serde(rename = "replaceBody")]
ReplaceBody { value: String },
#[serde(rename = "addHeader")]
AddHeader { name: String, value: String },
#[serde(rename = "insertHeader")]
InsertHeader {
index: i32,
name: String,
value: String,
},
#[serde(rename = "changeHeader")]
ChangeHeader {
index: i32,
name: String,
value: String,
},
#[serde(rename = "deleteHeader")]
DeleteHeader {
#[serde(default)]
index: Option<i32>,
name: String,
},
#[serde(rename = "quarantine")]
Quarantine { value: String },
}

View file

@ -23,7 +23,7 @@
use std::time::{Duration, SystemTime};
use common::{listener::SessionStream, scripts::ScriptModification};
use common::{config::smtp::session::Stage, listener::SessionStream, scripts::ScriptModification};
use mail_auth::{IprevOutput, IprevResult, SpfOutput, SpfResult};
use smtp_proto::{MailFrom, MtPriority, MAIL_BY_NOTIFY, MAIL_BY_RETURN, MAIL_REQUIRETLS};
use utils::config::Rate;
@ -167,6 +167,18 @@ impl<T: SessionStream> Session<T> {
}
}
// Milter filtering
if let Err(message) = self.run_milters(Stage::Mail, None).await {
tracing::info!(parent: &self.span,
context = "milter",
event = "reject",
address = &self.data.mail_from.as_ref().unwrap().address,
reason = std::str::from_utf8(message.as_ref()).unwrap_or_default());
self.data.mail_from = None;
return self.write(message.as_ref()).await;
}
// Address rewriting
if let Some(new_address) = self
.core

View file

@ -23,7 +23,10 @@
use std::borrow::Cow;
use common::{config::smtp::session::Milter, listener::SessionStream};
use common::{
config::smtp::session::{Milter, Stage},
listener::SessionStream,
};
use mail_auth::AuthenticatedMessage;
use smtp_proto::request::parser::Rfc5321Parser;
use tokio::io::{AsyncRead, AsyncWrite};
@ -45,21 +48,23 @@ enum Rejection {
impl<T: SessionStream> Session<T> {
pub async fn run_milters(
&self,
message: &AuthenticatedMessage<'_>,
stage: Stage,
message: Option<&AuthenticatedMessage<'_>>,
) -> Result<Vec<Modification>, Cow<'static, [u8]>> {
let milters = &self.core.core.smtp.session.data.milters;
let milters = &self.core.core.smtp.session.milters;
if milters.is_empty() {
return Ok(Vec::new());
}
let mut modifications = Vec::new();
for milter in milters {
if !self
.core
.core
.eval_if(&milter.enable, self)
.await
.unwrap_or(false)
if !milter.run_on_stage.contains(&stage)
|| !self
.core
.core
.eval_if(&milter.enable, self)
.await
.unwrap_or(false)
{
continue;
}
@ -132,7 +137,7 @@ impl<T: SessionStream> Session<T> {
async fn connect_and_run(
&self,
milter: &Milter,
message: &AuthenticatedMessage<'_>,
message: Option<&AuthenticatedMessage<'_>>,
) -> Result<Vec<Modification>, Rejection> {
// Build client
let client = MilterClient::connect(milter, self.span.clone()).await?;
@ -159,7 +164,7 @@ impl<T: SessionStream> Session<T> {
async fn run<S: AsyncRead + AsyncWrite + Unpin>(
&self,
mut client: MilterClient<S>,
message: &AuthenticatedMessage<'_>,
message: Option<&AuthenticatedMessage<'_>>,
) -> Result<Vec<Modification>, Rejection> {
// Option negotiation
client.init().await?;
@ -187,65 +192,74 @@ impl<T: SessionStream> Session<T> {
.assert_continue()?;
// EHLO/HELO
let (tls_version, tls_ciper) = self.stream.tls_version_and_cipher();
let (tls_version, tls_cipher) = self.stream.tls_version_and_cipher();
client
.helo(
&self.data.helo_domain,
Macros::new()
.with_cipher(tls_ciper.as_ref())
.with_cipher(tls_cipher.as_ref())
.with_tls_version(tls_version.as_ref()),
)
.await?
.assert_continue()?;
// Mail from
let addr = &self.data.mail_from.as_ref().unwrap().address_lcase;
client
.mail_from(
&format!("<{addr}>"),
None::<&[&str]>,
Macros::new()
.with_mail_address(addr)
.with_sasl_login_name(&self.data.authenticated_as),
)
.await?
.assert_continue()?;
// Rcpt to
for rcpt in &self.data.rcpt_to {
if let Some(mail_from) = &self.data.mail_from {
let addr = &mail_from.address_lcase;
client
.rcpt_to(
&format!("<{}>", rcpt.address_lcase),
.mail_from(
&format!("<{addr}>"),
None::<&[&str]>,
Macros::new().with_rcpt_address(&rcpt.address_lcase),
Macros::new()
.with_mail_address(addr)
.with_sasl_login_name(&self.data.authenticated_as),
)
.await?
.assert_continue()?;
// Rcpt to
for rcpt in &self.data.rcpt_to {
client
.rcpt_to(
&format!("<{}>", rcpt.address_lcase),
None::<&[&str]>,
Macros::new().with_rcpt_address(&rcpt.address_lcase),
)
.await?
.assert_continue()?;
}
}
// Data
client.data().await?.assert_continue()?;
if let Some(message) = message {
// Data
client.data().await?.assert_continue()?;
// Headers
client
.headers(message.raw_parsed_headers().iter().map(|(k, v)| {
(
std::str::from_utf8(k).unwrap_or_default(),
std::str::from_utf8(v).unwrap_or_default(),
)
}))
.await?
.assert_continue()?;
// Headers
client
.headers(message.raw_parsed_headers().iter().map(|(k, v)| {
(
std::str::from_utf8(k).unwrap_or_default(),
std::str::from_utf8(v).unwrap_or_default(),
)
}))
.await?
.assert_continue()?;
// Message body
let (action, modifications) = client.body(message.raw_message()).await?;
action.assert_continue()?;
// Message body
let (action, modifications) = client.body(message.raw_message()).await?;
action.assert_continue()?;
// Quit
let _ = client.quit().await;
// Quit
let _ = client.quit().await;
// Return modifications
Ok(modifications)
// Return modifications
Ok(modifications)
} else {
// Quit
let _ = client.quit().await;
Ok(Vec::new())
}
}
}

View file

@ -30,6 +30,7 @@ use mail_auth::{
pub mod auth;
pub mod data;
pub mod ehlo;
pub mod jmilter;
pub mod mail;
pub mod milter;
pub mod rcpt;

View file

@ -21,7 +21,7 @@
* for more details.
*/
use common::{listener::SessionStream, scripts::ScriptModification};
use common::{config::smtp::session::Stage, listener::SessionStream, scripts::ScriptModification};
use smtp_proto::{
RcptTo, RCPT_NOTIFY_DELAY, RCPT_NOTIFY_FAILURE, RCPT_NOTIFY_NEVER, RCPT_NOTIFY_SUCCESS,
};
@ -85,7 +85,17 @@ impl<T: SessionStream> Session<T> {
.and_then(|name| self.core.core.get_sieve_script(&name))
.cloned();
if rcpt_script.is_some() || !self.core.core.smtp.session.rcpt.rewrite.is_empty() {
if rcpt_script.is_some()
|| !self.core.core.smtp.session.rcpt.rewrite.is_empty()
|| self
.core
.core
.smtp
.session
.milters
.iter()
.any(|m| m.run_on_stage.contains(&Stage::Rcpt))
{
// Sieve filtering
if let Some(script) = rcpt_script {
match self
@ -121,6 +131,18 @@ impl<T: SessionStream> Session<T> {
}
}
// Milter filtering
if let Err(message) = self.run_milters(Stage::Rcpt, None).await {
tracing::info!(parent: &self.span,
context = "milter",
event = "reject",
address = self.data.rcpt_to.last().unwrap().address,
reason = std::str::from_utf8(message.as_ref()).unwrap_or_default());
self.data.rcpt_to.pop();
return self.write(message.as_ref()).await;
}
// Address rewriting
if let Some(new_address) = self
.core

View file

@ -23,7 +23,10 @@
use std::time::Instant;
use common::listener::{self, SessionManager, SessionStream};
use common::{
config::smtp::session::Stage,
listener::{self, SessionManager, SessionStream},
};
use tokio_rustls::server::TlsStream;
use crate::{
@ -118,6 +121,16 @@ impl<T: SessionStream> Session<T> {
}
}
// Milter filtering
if let Err(message) = self.run_milters(Stage::Connect, None).await {
tracing::debug!(parent: &self.span,
context = "connext",
event = "milter-reject",
reason = std::str::from_utf8(message.as_ref()).unwrap_or_default());
let _ = self.write(message.as_ref()).await;
return false;
}
// Obtain hostname
self.hostname = self
.core

View file

@ -23,8 +23,9 @@
use std::{fs, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};
use ahash::AHashSet;
use common::{
config::smtp::session::{Milter, MilterVersion},
config::smtp::session::{Milter, MilterVersion, Stage},
expr::if_block::IfBlock,
Core,
};
@ -73,7 +74,7 @@ path = "{TMP}/queue.db"
[session.rcpt]
relay = true
[[session.data.milter]]
[[session.milter]]
hostname = "127.0.0.1"
port = 9332
#port = 11332
@ -81,6 +82,7 @@ port = 9332
enable = true
options.version = 6
tls = false
stages = ["data"]
"#;
@ -419,6 +421,7 @@ async fn milter_client_test() {
protocol_version: MilterVersion::V6,
flags_actions: None,
flags_protocol: None,
run_on_stage: AHashSet::from([Stage::Data]),
},
tracing::span!(tracing::Level::TRACE, "hi"),
)