From b97a1dd52ce0925392f83018684d3f10db518da4 Mon Sep 17 00:00:00 2001 From: Carlos Quintana <74399022+cquintana92@users.noreply.github.com> Date: Wed, 2 Oct 2024 15:46:10 +0200 Subject: [PATCH] fix: improve MX and SPF domain handling (#2246) * fix: improve MX and SPF domain handling * fix: do not use get_instance function * fix: legacy use of tuples instead of MxRecord * refactor: rename partner custom domain variables --- app/config.py | 8 +- app/custom_domain_validation.py | 36 ++++- app/dashboard/views/domain_detail.py | 4 +- app/dns_utils.py | 37 +++--- app/email_utils.py | 2 +- app/models.py | 4 +- cron.py | 7 +- templates/dashboard/domain_detail/dns.html | 8 +- tests/test_custom_domain_validation.py | 148 ++++++++++++++++++++- tests/test_dns_utils.py | 31 +++-- 10 files changed, 240 insertions(+), 45 deletions(-) diff --git a/app/config.py b/app/config.py index 1962976e..0ad1d37e 100644 --- a/app/config.py +++ b/app/config.py @@ -653,9 +653,11 @@ def read_partner_dict(var: str) -> dict[int, str]: return res -PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS") -PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict( - "PARTNER_DOMAIN_VALIDATION_PREFIXES" +PARTNER_DNS_CUSTOM_DOMAINS: dict[int, str] = read_partner_dict( + "PARTNER_DNS_CUSTOM_DOMAINS" +) +PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict( + "PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES" ) MAILBOX_VERIFICATION_OVERRIDE_CODE: Optional[str] = os.environ.get( diff --git a/app/custom_domain_validation.py b/app/custom_domain_validation.py index 4dbf201a..9ebbd636 100644 --- a/app/custom_domain_validation.py +++ b/app/custom_domain_validation.py @@ -5,6 +5,7 @@ from app import config from app.constants import DMARC_RECORD from app.db import Session from app.dns_utils import ( + MxRecord, DNSClient, is_mx_equivalent, get_network_dns_client, @@ -28,10 +29,10 @@ class CustomDomainValidation: ): self.dkim_domain = dkim_domain self._dns_client = dns_client - self._partner_domains = partner_domains or config.PARTNER_DOMAINS + self._partner_domains = partner_domains or config.PARTNER_DNS_CUSTOM_DOMAINS self._partner_domain_validation_prefixes = ( partner_domains_validation_prefixes - or config.PARTNER_DOMAIN_VALIDATION_PREFIXES + or config.PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES ) def get_ownership_verification_record(self, domain: CustomDomain) -> str: @@ -43,6 +44,29 @@ class CustomDomainValidation: prefix = self._partner_domain_validation_prefixes[domain.partner_id] return f"{prefix}-verification={domain.ownership_txt_token}" + def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]: + records = [] + if domain.partner_id is not None and domain.partner_id in self._partner_domains: + domain = self._partner_domains[domain.partner_id] + records.append(MxRecord(10, f"mx1.{domain}.")) + records.append(MxRecord(20, f"mx2.{domain}.")) + else: + # Default ones + for priority, domain in config.EMAIL_SERVERS_WITH_PRIORITY: + records.append(MxRecord(priority, domain)) + + return records + + def get_expected_spf_domain(self, domain: CustomDomain) -> str: + if domain.partner_id is not None and domain.partner_id in self._partner_domains: + return self._partner_domains[domain.partner_id] + else: + return config.EMAIL_DOMAIN + + def get_expected_spf_record(self, domain: CustomDomain) -> str: + spf_domain = self.get_expected_spf_domain(domain) + return f"v=spf1 include:{spf_domain} ~all" + def get_dkim_records(self, domain: CustomDomain) -> {str: str}: """ Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not, @@ -116,11 +140,12 @@ class CustomDomainValidation: self, custom_domain: CustomDomain ) -> DomainValidationResult: mx_domains = self._dns_client.get_mx_domains(custom_domain.domain) + expected_mx_records = self.get_expected_mx_records(custom_domain) - if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY): + if not is_mx_equivalent(mx_domains, expected_mx_records): return DomainValidationResult( success=False, - errors=[f"{priority} {domain}" for (priority, domain) in mx_domains], + errors=[f"{record.priority} {record.domain}" for record in mx_domains], ) else: custom_domain.verified = True @@ -131,7 +156,8 @@ class CustomDomainValidation: self, custom_domain: CustomDomain ) -> DomainValidationResult: spf_domains = self._dns_client.get_spf_domain(custom_domain.domain) - if config.EMAIL_DOMAIN in spf_domains: + expected_spf_domain = self.get_expected_spf_domain(custom_domain) + if expected_spf_domain in spf_domains: custom_domain.spf_verified = True Session.commit() return DomainValidationResult(success=True, errors=[]) diff --git a/app/dashboard/views/domain_detail.py b/app/dashboard/views/domain_detail.py index 0911a748..2b1ac32f 100644 --- a/app/dashboard/views/domain_detail.py +++ b/app/dashboard/views/domain_detail.py @@ -36,8 +36,6 @@ def domain_detail_dns(custom_domain_id): custom_domain.ownership_txt_token = random_string(30) Session.commit() - spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all" - domain_validator = CustomDomainValidation(EMAIL_DOMAIN) csrf_form = CSRFValidationForm() @@ -141,7 +139,9 @@ def domain_detail_dns(custom_domain_id): ownership_record=domain_validator.get_ownership_verification_record( custom_domain ), + expected_mx_records=domain_validator.get_expected_mx_records(custom_domain), dkim_records=domain_validator.get_dkim_records(custom_domain), + spf_record=domain_validator.get_expected_spf_record(custom_domain), dmarc_record=DMARC_RECORD, **locals(), ) diff --git a/app/dns_utils.py b/app/dns_utils.py index 2ce69934..995a1e17 100644 --- a/app/dns_utils.py +++ b/app/dns_utils.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import List, Tuple, Optional +from dataclasses import dataclass +from typing import List, Optional import dns.resolver @@ -8,8 +9,14 @@ from app.config import NAMESERVERS _include_spf = "include:" +@dataclass +class MxRecord: + priority: int + domain: str + + def is_mx_equivalent( - mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]] + mx_domains: List[MxRecord], ref_mx_domains: List[MxRecord] ) -> bool: """ Compare mx_domains with ref_mx_domains to see if they are equivalent. @@ -18,14 +25,14 @@ def is_mx_equivalent( The priority order is taken into account but not the priority number. For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)] """ - mx_domains = sorted(mx_domains, key=lambda x: x[0]) - ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0]) + mx_domains = sorted(mx_domains, key=lambda x: x.priority) + ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x.priority) if len(mx_domains) < len(ref_mx_domains): return False - for i in range(len(ref_mx_domains)): - if mx_domains[i][1] != ref_mx_domains[i][1]: + for actual, expected in zip(mx_domains, ref_mx_domains): + if actual.domain != expected.domain: return False return True @@ -37,7 +44,7 @@ class DNSClient(ABC): pass @abstractmethod - def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + def get_mx_domains(self, hostname: str) -> List[MxRecord]: pass def get_spf_domain(self, hostname: str) -> List[str]: @@ -81,7 +88,7 @@ class NetworkDNSClient(DNSClient): except Exception: return None - def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + def get_mx_domains(self, hostname: str) -> List[MxRecord]: """ return list of (priority, domain name) sorted by priority (lowest priority first) domain name ends with a "." at the end. @@ -92,8 +99,8 @@ class NetworkDNSClient(DNSClient): for a in answers: record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.' parts = record.split(" ") - ret.append((int(parts[0]), parts[1])) - return sorted(ret, key=lambda x: x[0]) + ret.append(MxRecord(priority=int(parts[0]), domain=parts[1])) + return sorted(ret, key=lambda x: x.priority) except Exception: return [] @@ -112,14 +119,14 @@ class NetworkDNSClient(DNSClient): class InMemoryDNSClient(DNSClient): def __init__(self): self.cname_records: dict[str, Optional[str]] = {} - self.mx_records: dict[str, List[Tuple[int, str]]] = {} + self.mx_records: dict[str, List[MxRecord]] = {} self.spf_records: dict[str, List[str]] = {} self.txt_records: dict[str, List[str]] = {} def set_cname_record(self, hostname: str, cname: str): self.cname_records[hostname] = cname - def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]): + def set_mx_records(self, hostname: str, mx_list: List[MxRecord]): self.mx_records[hostname] = mx_list def set_txt_record(self, hostname: str, txt_list: List[str]): @@ -128,9 +135,9 @@ class InMemoryDNSClient(DNSClient): def get_cname_record(self, hostname: str) -> Optional[str]: return self.cname_records.get(hostname) - def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + def get_mx_domains(self, hostname: str) -> List[MxRecord]: mx_list = self.mx_records.get(hostname, []) - return sorted(mx_list, key=lambda x: x[0]) + return sorted(mx_list, key=lambda x: x.priority) def get_txt_record(self, hostname: str) -> List[str]: return self.txt_records.get(hostname, []) @@ -140,5 +147,5 @@ def get_network_dns_client() -> NetworkDNSClient: return NetworkDNSClient(NAMESERVERS) -def get_mx_domains(hostname: str) -> [(int, str)]: +def get_mx_domains(hostname: str) -> List[MxRecord]: return get_network_dns_client().get_mx_domains(hostname) diff --git a/app/email_utils.py b/app/email_utils.py index 5ff34d06..ca5aa041 100644 --- a/app/email_utils.py +++ b/app/email_utils.py @@ -657,7 +657,7 @@ def get_mx_domain_list(domain) -> [str]: """ priority_domains = get_mx_domains(domain) - return [d[:-1] for _, d in priority_domains] + return [d.domain[:-1] for d in priority_domains] def personal_email_already_used(email_address: str) -> bool: diff --git a/app/models.py b/app/models.py index c86c5df6..092c1060 100644 --- a/app/models.py +++ b/app/models.py @@ -2766,9 +2766,9 @@ class Mailbox(Base, ModelMixin): from app.email_utils import get_email_local_part - mx_domains: [(int, str)] = get_mx_domains(get_email_local_part(self.email)) + mx_domains = get_mx_domains(get_email_local_part(self.email)) # Proton is the first domain - if mx_domains and mx_domains[0][1] in ( + if mx_domains and mx_domains[0].domain in ( "mail.protonmail.ch.", "mailsec.protonmail.ch.", ): diff --git a/cron.py b/cron.py index bfc150d9..f8bb09f4 100644 --- a/cron.py +++ b/cron.py @@ -14,6 +14,7 @@ from sqlalchemy.sql import Insert, text from app import s3, config from app.alias_utils import nb_email_log_for_mailbox from app.api.views.apple import verify_receipt +from app.custom_domain_validation import CustomDomainValidation from app.db import Session from app.dns_utils import get_mx_domains, is_mx_equivalent from app.email_utils import ( @@ -905,9 +906,11 @@ def check_custom_domain(): LOG.i("custom domain has been deleted") -def check_single_custom_domain(custom_domain): +def check_single_custom_domain(custom_domain: CustomDomain): mx_domains = get_mx_domains(custom_domain.domain) - if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY): + validator = CustomDomainValidation(dkim_domain=config.EMAIL_DOMAIN) + expected_custom_domains = validator.get_expected_mx_records(custom_domain) + if not is_mx_equivalent(mx_domains, expected_custom_domains): user = custom_domain.user LOG.w( "The MX record is not correctly set for %s %s %s", diff --git a/templates/dashboard/domain_detail/dns.html b/templates/dashboard/domain_detail/dns.html index 4058f5ea..e100feea 100644 --- a/templates/dashboard/domain_detail/dns.html +++ b/templates/dashboard/domain_detail/dns.html @@ -91,7 +91,8 @@
Some domain registrars (Namecheap, CloudFlare, etc) might also use @ for the root domain. - {% for priority, email_server in EMAIL_SERVERS_WITH_PRIORITY %} + + {% for record in expected_mx_records %}
Record: MX @@ -99,14 +100,15 @@ Domain: {{ custom_domain.domain }} or @
- Priority: {{ priority }} + Priority: {{ record.priority }}
Target: {{ email_server }} + data-clipboard-text="{{ record.domain }}">{{ record.domain }}
{% endfor %} +
{{ csrf_form.csrf_token }} diff --git a/tests/test_custom_domain_validation.py b/tests/test_custom_domain_validation.py index b6e4386d..a0288b16 100644 --- a/tests/test_custom_domain_validation.py +++ b/tests/test_custom_domain_validation.py @@ -5,7 +5,7 @@ from app.constants import DMARC_RECORD from app.custom_domain_validation import CustomDomainValidation from app.db import Session from app.models import CustomDomain, User -from app.dns_utils import InMemoryDNSClient +from app.dns_utils import InMemoryDNSClient, MxRecord from app.proton.utils import get_proton_partner from app.utils import random_string from tests.utils import create_new_user, random_domain @@ -58,6 +58,123 @@ def test_custom_domain_validation_get_dkim_records_for_partner(): assert records["dkim._domainkey"] == f"dkim._domainkey.{dkim_domain}" +# get_expected_mx_records +def test_custom_domain_validation_get_expected_mx_records_regular_domain(): + domain = random_domain() + custom_domain = create_custom_domain(domain) + + partner_id = get_proton_partner().id + + dkim_domain = random_domain() + validator = CustomDomainValidation( + domain, partner_domains={partner_id: dkim_domain} + ) + records = validator.get_expected_mx_records(custom_domain) + # As the domain is not a partner_domain,default records should be used even if + # there is a config for the partner + assert len(records) == len(config.EMAIL_SERVERS_WITH_PRIORITY) + for i in range(len(config.EMAIL_SERVERS_WITH_PRIORITY)): + config_record = config.EMAIL_SERVERS_WITH_PRIORITY[i] + assert records[i].priority == config_record[0] + assert records[i].domain == config_record[1] + + +def test_custom_domain_validation_get_expected_mx_records_domain_from_partner(): + domain = random_domain() + custom_domain = create_custom_domain(domain) + + partner_id = get_proton_partner().id + custom_domain.partner_id = partner_id + Session.commit() + + dkim_domain = random_domain() + validator = CustomDomainValidation(dkim_domain) + records = validator.get_expected_mx_records(custom_domain) + # As the domain is a partner_domain but there is no custom config for partner, default records + # should be used + assert len(records) == len(config.EMAIL_SERVERS_WITH_PRIORITY) + for i in range(len(config.EMAIL_SERVERS_WITH_PRIORITY)): + config_record = config.EMAIL_SERVERS_WITH_PRIORITY[i] + assert records[i].priority == config_record[0] + assert records[i].domain == config_record[1] + + +def test_custom_domain_validation_get_expected_mx_records_domain_from_partner_with_custom_config(): + domain = random_domain() + custom_domain = create_custom_domain(domain) + + partner_id = get_proton_partner().id + custom_domain.partner_id = partner_id + Session.commit() + + dkim_domain = random_domain() + expected_mx_domain = random_domain() + validator = CustomDomainValidation( + dkim_domain, partner_domains={partner_id: expected_mx_domain} + ) + records = validator.get_expected_mx_records(custom_domain) + # As the domain is a partner_domain and there is a custom config for partner, partner records + # should be used + assert len(records) == 2 + + assert records[0].priority == 10 + assert records[0].domain == f"mx1.{expected_mx_domain}." + assert records[1].priority == 20 + assert records[1].domain == f"mx2.{expected_mx_domain}." + + +# get_expected_spf_records +def test_custom_domain_validation_get_expected_spf_record_regular_domain(): + domain = random_domain() + custom_domain = create_custom_domain(domain) + + partner_id = get_proton_partner().id + + dkim_domain = random_domain() + validator = CustomDomainValidation( + domain, partner_domains={partner_id: dkim_domain} + ) + record = validator.get_expected_spf_record(custom_domain) + # As the domain is not a partner_domain, default records should be used even if + # there is a config for the partner + assert record == f"v=spf1 include:{config.EMAIL_DOMAIN} ~all" + + +def test_custom_domain_validation_get_expected_spf_record_domain_from_partner(): + domain = random_domain() + custom_domain = create_custom_domain(domain) + + partner_id = get_proton_partner().id + custom_domain.partner_id = partner_id + Session.commit() + + dkim_domain = random_domain() + validator = CustomDomainValidation(dkim_domain) + record = validator.get_expected_spf_record(custom_domain) + # As the domain is a partner_domain but there is no custom config for partner, default records + # should be used + assert record == f"v=spf1 include:{config.EMAIL_DOMAIN} ~all" + + +def test_custom_domain_validation_get_expected_spf_record_domain_from_partner_with_custom_config(): + domain = random_domain() + custom_domain = create_custom_domain(domain) + + partner_id = get_proton_partner().id + custom_domain.partner_id = partner_id + Session.commit() + + dkim_domain = random_domain() + expected_mx_domain = random_domain() + validator = CustomDomainValidation( + dkim_domain, partner_domains={partner_id: expected_mx_domain} + ) + record = validator.get_expected_spf_record(custom_domain) + # As the domain is a partner_domain and there is a custom config for partner, partner records + # should be used + assert record == f"v=spf1 include:{expected_mx_domain} ~all" + + # validate_dkim_records def test_custom_domain_validation_validate_dkim_records_empty_records_failure(): dns_client = InMemoryDNSClient() @@ -253,7 +370,7 @@ def test_custom_domain_validation_validate_mx_records_wrong_records_failure(): wrong_record_1 = random_string() wrong_record_2 = random_string() - wrong_records = [(10, wrong_record_1), (20, wrong_record_2)] + wrong_records = [MxRecord(10, wrong_record_1), MxRecord(20, wrong_record_2)] dns_client.set_mx_records(domain.domain, wrong_records) res = validator.validate_mx_records(domain) @@ -270,7 +387,7 @@ def test_custom_domain_validation_validate_mx_records_success(): domain = create_custom_domain(random_domain()) - dns_client.set_mx_records(domain.domain, config.EMAIL_SERVERS_WITH_PRIORITY) + dns_client.set_mx_records(domain.domain, validator.get_expected_mx_records(domain)) res = validator.validate_mx_records(domain) assert res.success is True @@ -328,6 +445,31 @@ def test_custom_domain_validation_validate_spf_records_success(): assert db_domain.spf_verified is True +def test_custom_domain_validation_validate_spf_records_partner_domain_success(): + dns_client = InMemoryDNSClient() + proton_partner_id = get_proton_partner().id + + expected_domain = random_domain() + validator = CustomDomainValidation( + dkim_domain=random_domain(), + dns_client=dns_client, + partner_domains={proton_partner_id: expected_domain}, + ) + + domain = create_custom_domain(random_domain()) + domain.partner_id = proton_partner_id + Session.commit() + + dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{expected_domain}"]) + res = validator.validate_spf_records(domain) + + assert res.success is True + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.spf_verified is True + + # validate_dmarc_records def test_custom_domain_validation_validate_dmarc_records_empty_failure(): dns_client = InMemoryDNSClient() diff --git a/tests/test_dns_utils.py b/tests/test_dns_utils.py index 374983c8..15b2b9af 100644 --- a/tests/test_dns_utils.py +++ b/tests/test_dns_utils.py @@ -3,6 +3,7 @@ from app.dns_utils import ( get_network_dns_client, is_mx_equivalent, InMemoryDNSClient, + MxRecord, ) from tests.utils import random_domain @@ -17,8 +18,8 @@ def test_get_mx_domains(): assert len(r) > 0 for x in r: - assert x[0] > 0 - assert x[1] + assert x.priority > 0 + assert x.domain def test_get_spf_domain(): @@ -33,20 +34,32 @@ def test_get_txt_record(): def test_is_mx_equivalent(): assert is_mx_equivalent([], []) - assert is_mx_equivalent([(1, "domain")], [(1, "domain")]) assert is_mx_equivalent( - [(10, "domain1"), (20, "domain2")], [(10, "domain1"), (20, "domain2")] + mx_domains=[MxRecord(1, "domain")], ref_mx_domains=[MxRecord(1, "domain")] ) assert is_mx_equivalent( - [(5, "domain1"), (10, "domain2")], [(10, "domain1"), (20, "domain2")] + mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")], + ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")], ) assert is_mx_equivalent( - [(5, "domain1"), (10, "domain2"), (20, "domain3")], - [(10, "domain1"), (20, "domain2")], + mx_domains=[MxRecord(5, "domain1"), MxRecord(10, "domain2")], + ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")], + ) + assert is_mx_equivalent( + mx_domains=[ + MxRecord(5, "domain1"), + MxRecord(10, "domain2"), + MxRecord(20, "domain3"), + ], + ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")], ) assert not is_mx_equivalent( - [(5, "domain1"), (10, "domain2")], - [(10, "domain1"), (20, "domain2"), (20, "domain3")], + mx_domains=[MxRecord(5, "domain1"), MxRecord(10, "domain2")], + ref_mx_domains=[ + MxRecord(10, "domain1"), + MxRecord(20, "domain2"), + MxRecord(20, "domain3"), + ], )