diff --git a/app/custom_domain_validation.py b/app/custom_domain_validation.py index 9ebbd636..ce4b844e 100644 --- a/app/custom_domain_validation.py +++ b/app/custom_domain_validation.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional +from typing import List, Optional from app import config from app.constants import DMARC_RECORD @@ -11,6 +11,7 @@ from app.dns_utils import ( get_network_dns_client, ) from app.models import CustomDomain +from app.utils import random_string @dataclass @@ -42,6 +43,11 @@ class CustomDomainValidation: and domain.partner_id in self._partner_domain_validation_prefixes ): prefix = self._partner_domain_validation_prefixes[domain.partner_id] + + if not domain.ownership_txt_token: + domain.ownership_txt_token = random_string(30) + Session.commit() + return f"{prefix}-verification={domain.ownership_txt_token}" def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]: @@ -164,9 +170,11 @@ class CustomDomainValidation: else: custom_domain.spf_verified = False Session.commit() + txt_records = self._dns_client.get_txt_record(custom_domain.domain) + cleaned_records = self.__clean_spf_records(txt_records, custom_domain) return DomainValidationResult( success=False, - errors=self._dns_client.get_txt_record(custom_domain.domain), + errors=cleaned_records, ) def validate_dmarc_records( @@ -181,3 +189,13 @@ class CustomDomainValidation: custom_domain.dmarc_verified = False Session.commit() return DomainValidationResult(success=False, errors=txt_records) + + def __clean_spf_records( + self, txt_records: List[str], custom_domain: CustomDomain + ) -> List[str]: + final_records = [] + verification_record = self.get_ownership_verification_record(custom_domain) + for record in txt_records: + if record != verification_record: + final_records.append(record) + return final_records diff --git a/app/dashboard/views/custom_domain.py b/app/dashboard/views/custom_domain.py index b410b306..5b702492 100644 --- a/app/dashboard/views/custom_domain.py +++ b/app/dashboard/views/custom_domain.py @@ -21,7 +21,9 @@ class NewCustomDomainForm(FlaskForm): @parallel_limiter.lock(only_when=lambda: request.method == "POST") def custom_domain(): custom_domains = CustomDomain.filter_by( - user_id=current_user.id, is_sl_subdomain=False + user_id=current_user.id, + is_sl_subdomain=False, + pending_deletion=False, ).all() new_custom_domain_form = NewCustomDomainForm() diff --git a/app/dns_utils.py b/app/dns_utils.py index 995a1e17..202f1099 100644 --- a/app/dns_utils.py +++ b/app/dns_utils.py @@ -106,7 +106,7 @@ class NetworkDNSClient(DNSClient): def get_txt_record(self, hostname: str) -> List[str]: try: - answers = self._resolver.resolve(hostname, "TXT", search=True) + answers = self._resolver.resolve(hostname, "TXT", search=False) ret = [] for a in answers: # type: dns.rdtypes.ANY.TXT.TXT for record in a.strings: diff --git a/tests/test_custom_domain_validation.py b/tests/test_custom_domain_validation.py index a0288b16..d0de3dbe 100644 --- a/tests/test_custom_domain_validation.py +++ b/tests/test_custom_domain_validation.py @@ -470,6 +470,33 @@ def test_custom_domain_validation_validate_spf_records_partner_domain_success(): assert db_domain.spf_verified is True +def test_custom_domain_validation_validate_spf_cleans_verification_record(): + dns_client = InMemoryDNSClient() + proton_partner_id = get_proton_partner().id + + expected_domain = random_domain() + validator = CustomDomainValidation( + dkim_domain=random_domain(), + dns_client=dns_client, + partner_domains={proton_partner_id: expected_domain}, + ) + + domain = create_custom_domain(random_domain()) + domain.partner_id = proton_partner_id + Session.commit() + + wrong_record = random_string() + dns_client.set_txt_record( + hostname=domain.domain, + txt_list=[wrong_record, validator.get_ownership_verification_record(domain)], + ) + res = validator.validate_spf_records(domain) + + assert res.success is False + assert len(res.errors) == 1 + assert res.errors[0] == wrong_record + + # validate_dmarc_records def test_custom_domain_validation_validate_dmarc_records_empty_failure(): dns_client = InMemoryDNSClient()