diff --git a/app/config.py b/app/config.py
index 1962976e..0ad1d37e 100644
--- a/app/config.py
+++ b/app/config.py
@@ -653,9 +653,11 @@ def read_partner_dict(var: str) -> dict[int, str]:
return res
-PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS")
-PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
- "PARTNER_DOMAIN_VALIDATION_PREFIXES"
+PARTNER_DNS_CUSTOM_DOMAINS: dict[int, str] = read_partner_dict(
+ "PARTNER_DNS_CUSTOM_DOMAINS"
+)
+PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
+ "PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES"
)
MAILBOX_VERIFICATION_OVERRIDE_CODE: Optional[str] = os.environ.get(
diff --git a/app/custom_domain_validation.py b/app/custom_domain_validation.py
index 4dbf201a..9ebbd636 100644
--- a/app/custom_domain_validation.py
+++ b/app/custom_domain_validation.py
@@ -5,6 +5,7 @@ from app import config
from app.constants import DMARC_RECORD
from app.db import Session
from app.dns_utils import (
+ MxRecord,
DNSClient,
is_mx_equivalent,
get_network_dns_client,
@@ -28,10 +29,10 @@ class CustomDomainValidation:
):
self.dkim_domain = dkim_domain
self._dns_client = dns_client
- self._partner_domains = partner_domains or config.PARTNER_DOMAINS
+ self._partner_domains = partner_domains or config.PARTNER_DNS_CUSTOM_DOMAINS
self._partner_domain_validation_prefixes = (
partner_domains_validation_prefixes
- or config.PARTNER_DOMAIN_VALIDATION_PREFIXES
+ or config.PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES
)
def get_ownership_verification_record(self, domain: CustomDomain) -> str:
@@ -43,6 +44,29 @@ class CustomDomainValidation:
prefix = self._partner_domain_validation_prefixes[domain.partner_id]
return f"{prefix}-verification={domain.ownership_txt_token}"
+ def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]:
+ records = []
+ if domain.partner_id is not None and domain.partner_id in self._partner_domains:
+ domain = self._partner_domains[domain.partner_id]
+ records.append(MxRecord(10, f"mx1.{domain}."))
+ records.append(MxRecord(20, f"mx2.{domain}."))
+ else:
+ # Default ones
+ for priority, domain in config.EMAIL_SERVERS_WITH_PRIORITY:
+ records.append(MxRecord(priority, domain))
+
+ return records
+
+ def get_expected_spf_domain(self, domain: CustomDomain) -> str:
+ if domain.partner_id is not None and domain.partner_id in self._partner_domains:
+ return self._partner_domains[domain.partner_id]
+ else:
+ return config.EMAIL_DOMAIN
+
+ def get_expected_spf_record(self, domain: CustomDomain) -> str:
+ spf_domain = self.get_expected_spf_domain(domain)
+ return f"v=spf1 include:{spf_domain} ~all"
+
def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
"""
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
@@ -116,11 +140,12 @@ class CustomDomainValidation:
self, custom_domain: CustomDomain
) -> DomainValidationResult:
mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
+ expected_mx_records = self.get_expected_mx_records(custom_domain)
- if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY):
+ if not is_mx_equivalent(mx_domains, expected_mx_records):
return DomainValidationResult(
success=False,
- errors=[f"{priority} {domain}" for (priority, domain) in mx_domains],
+ errors=[f"{record.priority} {record.domain}" for record in mx_domains],
)
else:
custom_domain.verified = True
@@ -131,7 +156,8 @@ class CustomDomainValidation:
self, custom_domain: CustomDomain
) -> DomainValidationResult:
spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
- if config.EMAIL_DOMAIN in spf_domains:
+ expected_spf_domain = self.get_expected_spf_domain(custom_domain)
+ if expected_spf_domain in spf_domains:
custom_domain.spf_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
diff --git a/app/dashboard/views/domain_detail.py b/app/dashboard/views/domain_detail.py
index 0911a748..2b1ac32f 100644
--- a/app/dashboard/views/domain_detail.py
+++ b/app/dashboard/views/domain_detail.py
@@ -36,8 +36,6 @@ def domain_detail_dns(custom_domain_id):
custom_domain.ownership_txt_token = random_string(30)
Session.commit()
- spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all"
-
domain_validator = CustomDomainValidation(EMAIL_DOMAIN)
csrf_form = CSRFValidationForm()
@@ -141,7 +139,9 @@ def domain_detail_dns(custom_domain_id):
ownership_record=domain_validator.get_ownership_verification_record(
custom_domain
),
+ expected_mx_records=domain_validator.get_expected_mx_records(custom_domain),
dkim_records=domain_validator.get_dkim_records(custom_domain),
+ spf_record=domain_validator.get_expected_spf_record(custom_domain),
dmarc_record=DMARC_RECORD,
**locals(),
)
diff --git a/app/dns_utils.py b/app/dns_utils.py
index 2ce69934..995a1e17 100644
--- a/app/dns_utils.py
+++ b/app/dns_utils.py
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
-from typing import List, Tuple, Optional
+from dataclasses import dataclass
+from typing import List, Optional
import dns.resolver
@@ -8,8 +9,14 @@ from app.config import NAMESERVERS
_include_spf = "include:"
+@dataclass
+class MxRecord:
+ priority: int
+ domain: str
+
+
def is_mx_equivalent(
- mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
+ mx_domains: List[MxRecord], ref_mx_domains: List[MxRecord]
) -> bool:
"""
Compare mx_domains with ref_mx_domains to see if they are equivalent.
@@ -18,14 +25,14 @@ def is_mx_equivalent(
The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
"""
- mx_domains = sorted(mx_domains, key=lambda x: x[0])
- ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0])
+ mx_domains = sorted(mx_domains, key=lambda x: x.priority)
+ ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x.priority)
if len(mx_domains) < len(ref_mx_domains):
return False
- for i in range(len(ref_mx_domains)):
- if mx_domains[i][1] != ref_mx_domains[i][1]:
+ for actual, expected in zip(mx_domains, ref_mx_domains):
+ if actual.domain != expected.domain:
return False
return True
@@ -37,7 +44,7 @@ class DNSClient(ABC):
pass
@abstractmethod
- def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
+ def get_mx_domains(self, hostname: str) -> List[MxRecord]:
pass
def get_spf_domain(self, hostname: str) -> List[str]:
@@ -81,7 +88,7 @@ class NetworkDNSClient(DNSClient):
except Exception:
return None
- def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
+ def get_mx_domains(self, hostname: str) -> List[MxRecord]:
"""
return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
@@ -92,8 +99,8 @@ class NetworkDNSClient(DNSClient):
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
- ret.append((int(parts[0]), parts[1]))
- return sorted(ret, key=lambda x: x[0])
+ ret.append(MxRecord(priority=int(parts[0]), domain=parts[1]))
+ return sorted(ret, key=lambda x: x.priority)
except Exception:
return []
@@ -112,14 +119,14 @@ class NetworkDNSClient(DNSClient):
class InMemoryDNSClient(DNSClient):
def __init__(self):
self.cname_records: dict[str, Optional[str]] = {}
- self.mx_records: dict[str, List[Tuple[int, str]]] = {}
+ self.mx_records: dict[str, List[MxRecord]] = {}
self.spf_records: dict[str, List[str]] = {}
self.txt_records: dict[str, List[str]] = {}
def set_cname_record(self, hostname: str, cname: str):
self.cname_records[hostname] = cname
- def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]):
+ def set_mx_records(self, hostname: str, mx_list: List[MxRecord]):
self.mx_records[hostname] = mx_list
def set_txt_record(self, hostname: str, txt_list: List[str]):
@@ -128,9 +135,9 @@ class InMemoryDNSClient(DNSClient):
def get_cname_record(self, hostname: str) -> Optional[str]:
return self.cname_records.get(hostname)
- def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
+ def get_mx_domains(self, hostname: str) -> List[MxRecord]:
mx_list = self.mx_records.get(hostname, [])
- return sorted(mx_list, key=lambda x: x[0])
+ return sorted(mx_list, key=lambda x: x.priority)
def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, [])
@@ -140,5 +147,5 @@ def get_network_dns_client() -> NetworkDNSClient:
return NetworkDNSClient(NAMESERVERS)
-def get_mx_domains(hostname: str) -> [(int, str)]:
+def get_mx_domains(hostname: str) -> List[MxRecord]:
return get_network_dns_client().get_mx_domains(hostname)
diff --git a/app/email_utils.py b/app/email_utils.py
index 5ff34d06..ca5aa041 100644
--- a/app/email_utils.py
+++ b/app/email_utils.py
@@ -657,7 +657,7 @@ def get_mx_domain_list(domain) -> [str]:
"""
priority_domains = get_mx_domains(domain)
- return [d[:-1] for _, d in priority_domains]
+ return [d.domain[:-1] for d in priority_domains]
def personal_email_already_used(email_address: str) -> bool:
diff --git a/app/models.py b/app/models.py
index c86c5df6..092c1060 100644
--- a/app/models.py
+++ b/app/models.py
@@ -2766,9 +2766,9 @@ class Mailbox(Base, ModelMixin):
from app.email_utils import get_email_local_part
- mx_domains: [(int, str)] = get_mx_domains(get_email_local_part(self.email))
+ mx_domains = get_mx_domains(get_email_local_part(self.email))
# Proton is the first domain
- if mx_domains and mx_domains[0][1] in (
+ if mx_domains and mx_domains[0].domain in (
"mail.protonmail.ch.",
"mailsec.protonmail.ch.",
):
diff --git a/cron.py b/cron.py
index bfc150d9..f8bb09f4 100644
--- a/cron.py
+++ b/cron.py
@@ -14,6 +14,7 @@ from sqlalchemy.sql import Insert, text
from app import s3, config
from app.alias_utils import nb_email_log_for_mailbox
from app.api.views.apple import verify_receipt
+from app.custom_domain_validation import CustomDomainValidation
from app.db import Session
from app.dns_utils import get_mx_domains, is_mx_equivalent
from app.email_utils import (
@@ -905,9 +906,11 @@ def check_custom_domain():
LOG.i("custom domain has been deleted")
-def check_single_custom_domain(custom_domain):
+def check_single_custom_domain(custom_domain: CustomDomain):
mx_domains = get_mx_domains(custom_domain.domain)
- if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY):
+ validator = CustomDomainValidation(dkim_domain=config.EMAIL_DOMAIN)
+ expected_custom_domains = validator.get_expected_mx_records(custom_domain)
+ if not is_mx_equivalent(mx_domains, expected_custom_domains):
user = custom_domain.user
LOG.w(
"The MX record is not correctly set for %s %s %s",
diff --git a/templates/dashboard/domain_detail/dns.html b/templates/dashboard/domain_detail/dns.html
index 4058f5ea..e100feea 100644
--- a/templates/dashboard/domain_detail/dns.html
+++ b/templates/dashboard/domain_detail/dns.html
@@ -91,7 +91,8 @@
Some domain registrars (Namecheap, CloudFlare, etc) might also use @ for the root domain.
- {% for priority, email_server in EMAIL_SERVERS_WITH_PRIORITY %}
+
+ {% for record in expected_mx_records %}