Workaround for sqlx bug (#15)

This commit is contained in:
mdecimus 2023-07-24 11:48:23 +02:00
parent e8df912b27
commit 9725aa75f2
4 changed files with 35 additions and 20 deletions

1
Cargo.lock generated
View file

@ -1029,6 +1029,7 @@ dependencies = [
"argon2",
"async-trait",
"bb8",
"futures",
"ldap3",
"lru-cache",
"mail-builder",

View file

@ -29,6 +29,7 @@ scrypt = "0.11.0"
sha1 = "0.10.5"
sha2 = "0.10.6"
md5 = "0.7.0"
futures = "0.3"
[dev-dependencies]
tokio = { version = "1.23", features = ["full"] }

View file

@ -21,6 +21,7 @@
* for more details.
*/
use futures::TryStreamExt;
use mail_send::Credentials;
use sqlx::{any::AnyRow, Column, Row};
@ -48,18 +49,20 @@ impl Directory for SqlDirectory {
}
async fn principal(&self, name: &str) -> crate::Result<Option<Principal>> {
if let Some(row) = sqlx::query(&self.mappings.query_name)
let result = sqlx::query(&self.mappings.query_name)
.bind(name)
.fetch_optional(&self.pool)
.await?
{
.fetch(&self.pool)
.try_next()
.await?;
if let Some(row) = result {
// Map row to principal
let mut principal = self.mappings.row_to_principal(row)?;
// Obtain members
principal.member_of = sqlx::query_scalar::<_, String>(&self.mappings.query_members)
.bind(name)
.fetch_all(&self.pool)
.fetch(&self.pool)
.try_collect::<Vec<_>>()
.await?;
// Check whether the user is a superuser
@ -81,22 +84,25 @@ impl Directory for SqlDirectory {
async fn emails_by_name(&self, name: &str) -> crate::Result<Vec<String>> {
sqlx::query_scalar::<_, String>(&self.mappings.query_emails)
.bind(name)
.fetch_all(&self.pool)
.fetch(&self.pool)
.try_collect::<Vec<_>>()
.await
.map_err(Into::into)
}
async fn names_by_email(&self, address: &str) -> crate::Result<Vec<String>> {
match sqlx::query_scalar::<_, String>(&self.mappings.query_recipients)
let result = sqlx::query_scalar::<_, String>(&self.mappings.query_recipients)
.bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref())
.fetch_all(&self.pool)
.await
{
.fetch(&self.pool)
.try_collect::<Vec<_>>()
.await;
match result {
Ok(ids) if !ids.is_empty() => Ok(ids),
Ok(_) if self.opt.catch_all => {
sqlx::query_scalar::<_, String>(&self.mappings.query_recipients)
.bind(to_catch_all_address(address))
.fetch_all(&self.pool)
.fetch(&self.pool)
.try_collect::<Vec<_>>()
.await
.map_err(Into::into)
}
@ -106,15 +112,17 @@ impl Directory for SqlDirectory {
}
async fn rcpt(&self, address: &str) -> crate::Result<bool> {
match sqlx::query(&self.mappings.query_recipients)
let result = sqlx::query(&self.mappings.query_recipients)
.bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref())
.fetch_optional(&self.pool)
.await
{
.fetch(&self.pool)
.try_next()
.await;
match result {
Ok(Some(_)) => Ok(true),
Ok(None) if self.opt.catch_all => sqlx::query(&self.mappings.query_recipients)
.bind(to_catch_all_address(address))
.fetch_optional(&self.pool)
.fetch(&self.pool)
.try_next()
.await
.map(|id| id.is_some())
.map_err(Into::into),
@ -126,7 +134,8 @@ impl Directory for SqlDirectory {
async fn vrfy(&self, address: &str) -> crate::Result<Vec<String>> {
sqlx::query_scalar::<_, String>(&self.mappings.query_verify)
.bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref())
.fetch_all(&self.pool)
.fetch(&self.pool)
.try_collect::<Vec<_>>()
.await
.map_err(Into::into)
}
@ -134,7 +143,8 @@ impl Directory for SqlDirectory {
async fn expn(&self, address: &str) -> crate::Result<Vec<String>> {
sqlx::query_scalar::<_, String>(&self.mappings.query_expand)
.bind(unwrap_subaddress(address, self.opt.subaddressing).as_ref())
.fetch_all(&self.pool)
.fetch(&self.pool)
.try_collect::<Vec<_>>()
.await
.map_err(Into::into)
}
@ -146,7 +156,8 @@ impl Directory for SqlDirectory {
q = q.bind(param);
}
q.fetch_optional(&self.pool)
q.fetch(&self.pool)
.try_next()
.await
.map(|r| r.is_some())
.map_err(Into::into)
@ -155,7 +166,8 @@ impl Directory for SqlDirectory {
async fn is_local_domain(&self, domain: &str) -> crate::Result<bool> {
sqlx::query(&self.mappings.query_domains)
.bind(domain)
.fetch_optional(&self.pool)
.fetch(&self.pool)
.try_next()
.await
.map(|id| id.is_some())
.map_err(Into::into)

View file

@ -37,6 +37,7 @@ const CONFIG: &str = r#"
[directory."sql"]
type = "sql"
address = "sqlite::memory:"
#address = "mysql://root:secret@localhost:3306/stalwart?ssl_mode=disabled"
[directory."sql".options]
catch-all = true