Workaround for sqlx bug (#15)

This commit is contained in:
mdecimus 2023-07-24 11:48:23 +02:00
parent e8df912b27
commit 9725aa75f2
4 changed files with 35 additions and 20 deletions

1
Cargo.lock generated
View file

@ -1029,6 +1029,7 @@ dependencies = [
"argon2", "argon2",
"async-trait", "async-trait",
"bb8", "bb8",
"futures",
"ldap3", "ldap3",
"lru-cache", "lru-cache",
"mail-builder", "mail-builder",

View file

@ -29,6 +29,7 @@ scrypt = "0.11.0"
sha1 = "0.10.5" sha1 = "0.10.5"
sha2 = "0.10.6" sha2 = "0.10.6"
md5 = "0.7.0" md5 = "0.7.0"
futures = "0.3"
[dev-dependencies] [dev-dependencies]
tokio = { version = "1.23", features = ["full"] } tokio = { version = "1.23", features = ["full"] }

View file

@ -21,6 +21,7 @@
* for more details. * for more details.
*/ */
use futures::TryStreamExt;
use mail_send::Credentials; use mail_send::Credentials;
use sqlx::{any::AnyRow, Column, Row}; use sqlx::{any::AnyRow, Column, Row};
@ -48,18 +49,20 @@ impl Directory for SqlDirectory {
} }
async fn principal(&self, name: &str) -> crate::Result<Option<Principal>> { async fn principal(&self, name: &str) -> crate::Result<Option<Principal>> {
if let Some(row) = sqlx::query(&self.mappings.query_name) let result = sqlx::query(&self.mappings.query_name)
.bind(name) .bind(name)
.fetch_optional(&self.pool) .fetch(&self.pool)
.await? .try_next()
{ .await?;
if let Some(row) = result {
// Map row to principal // Map row to principal
let mut principal = self.mappings.row_to_principal(row)?; let mut principal = self.mappings.row_to_principal(row)?;
// Obtain members // Obtain members
principal.member_of = sqlx::query_scalar::<_, String>(&self.mappings.query_members) principal.member_of = sqlx::query_scalar::<_, String>(&self.mappings.query_members)
.bind(name) .bind(name)
.fetch_all(&self.pool) .fetch(&self.pool)
.try_collect::<Vec<_>>()
.await?; .await?;
// Check whether the user is a superuser // Check whether the user is a superuser
@ -81,22 +84,25 @@ impl Directory for SqlDirectory {
async fn emails_by_name(&self, name: &str) -> crate::Result<Vec<String>> { async fn emails_by_name(&self, name: &str) -> crate::Result<Vec<String>> {
sqlx::query_scalar::<_, String>(&self.mappings.query_emails) sqlx::query_scalar::<_, String>(&self.mappings.query_emails)
.bind(name) .bind(name)
.fetch_all(&self.pool) .fetch(&self.pool)
.try_collect::<Vec<_>>()
.await .await
.map_err(Into::into) .map_err(Into::into)
} }
async fn names_by_email(&self, address: &str) -> crate::Result<Vec<String>> { async fn names_by_email(&self, address: &str) -> crate::Result<Vec<String>> {
match sqlx::query_scalar::<_, String>(&self.mappings.query_recipients) let result = sqlx::query_scalar::<_, String>(&self.mappings.query_recipients)
.bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref()) .bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref())
.fetch_all(&self.pool) .fetch(&self.pool)
.await .try_collect::<Vec<_>>()
{ .await;
match result {
Ok(ids) if !ids.is_empty() => Ok(ids), Ok(ids) if !ids.is_empty() => Ok(ids),
Ok(_) if self.opt.catch_all => { Ok(_) if self.opt.catch_all => {
sqlx::query_scalar::<_, String>(&self.mappings.query_recipients) sqlx::query_scalar::<_, String>(&self.mappings.query_recipients)
.bind(to_catch_all_address(address)) .bind(to_catch_all_address(address))
.fetch_all(&self.pool) .fetch(&self.pool)
.try_collect::<Vec<_>>()
.await .await
.map_err(Into::into) .map_err(Into::into)
} }
@ -106,15 +112,17 @@ impl Directory for SqlDirectory {
} }
async fn rcpt(&self, address: &str) -> crate::Result<bool> { async fn rcpt(&self, address: &str) -> crate::Result<bool> {
match sqlx::query(&self.mappings.query_recipients) let result = sqlx::query(&self.mappings.query_recipients)
.bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref()) .bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref())
.fetch_optional(&self.pool) .fetch(&self.pool)
.await .try_next()
{ .await;
match result {
Ok(Some(_)) => Ok(true), Ok(Some(_)) => Ok(true),
Ok(None) if self.opt.catch_all => sqlx::query(&self.mappings.query_recipients) Ok(None) if self.opt.catch_all => sqlx::query(&self.mappings.query_recipients)
.bind(to_catch_all_address(address)) .bind(to_catch_all_address(address))
.fetch_optional(&self.pool) .fetch(&self.pool)
.try_next()
.await .await
.map(|id| id.is_some()) .map(|id| id.is_some())
.map_err(Into::into), .map_err(Into::into),
@ -126,7 +134,8 @@ impl Directory for SqlDirectory {
async fn vrfy(&self, address: &str) -> crate::Result<Vec<String>> { async fn vrfy(&self, address: &str) -> crate::Result<Vec<String>> {
sqlx::query_scalar::<_, String>(&self.mappings.query_verify) sqlx::query_scalar::<_, String>(&self.mappings.query_verify)
.bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref()) .bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref())
.fetch_all(&self.pool) .fetch(&self.pool)
.try_collect::<Vec<_>>()
.await .await
.map_err(Into::into) .map_err(Into::into)
} }
@ -134,7 +143,8 @@ impl Directory for SqlDirectory {
async fn expn(&self, address: &str) -> crate::Result<Vec<String>> { async fn expn(&self, address: &str) -> crate::Result<Vec<String>> {
sqlx::query_scalar::<_, String>(&self.mappings.query_expand) sqlx::query_scalar::<_, String>(&self.mappings.query_expand)
.bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref()) .bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref())
.fetch_all(&self.pool) .fetch(&self.pool)
.try_collect::<Vec<_>>()
.await .await
.map_err(Into::into) .map_err(Into::into)
} }
@ -146,7 +156,8 @@ impl Directory for SqlDirectory {
q = q.bind(param); q = q.bind(param);
} }
q.fetch_optional(&self.pool) q.fetch(&self.pool)
.try_next()
.await .await
.map(|r| r.is_some()) .map(|r| r.is_some())
.map_err(Into::into) .map_err(Into::into)
@ -155,7 +166,8 @@ impl Directory for SqlDirectory {
async fn is_local_domain(&self, domain: &str) -> crate::Result<bool> { async fn is_local_domain(&self, domain: &str) -> crate::Result<bool> {
sqlx::query(&self.mappings.query_domains) sqlx::query(&self.mappings.query_domains)
.bind(domain) .bind(domain)
.fetch_optional(&self.pool) .fetch(&self.pool)
.try_next()
.await .await
.map(|id| id.is_some()) .map(|id| id.is_some())
.map_err(Into::into) .map_err(Into::into)

View file

@ -37,6 +37,7 @@ const CONFIG: &str = r#"
[directory."sql"] [directory."sql"]
type = "sql" type = "sql"
address = "sqlite::memory:" address = "sqlite::memory:"
#address = "mysql://root:secret@localhost:3306/stalwart?ssl_mode=disabled"
[directory."sql".options] [directory."sql".options]
catch-all = true catch-all = true