chore: DNS validation improvements (#2248)

* chore: DNS validation improvements

* fix: do not show domains pending deletion

* fix: generate verification token if null

* revert: dmarc cleanup
This commit is contained in:
Carlos Quintana 2024-10-03 13:04:17 +02:00 committed by GitHub
parent 06ab116476
commit 9d5697b624
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 51 additions and 4 deletions

View file

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional
from app import config
from app.constants import DMARC_RECORD
@ -11,6 +11,7 @@ from app.dns_utils import (
get_network_dns_client,
)
from app.models import CustomDomain
from app.utils import random_string
@dataclass
@ -42,6 +43,11 @@ class CustomDomainValidation:
and domain.partner_id in self._partner_domain_validation_prefixes
):
prefix = self._partner_domain_validation_prefixes[domain.partner_id]
if not domain.ownership_txt_token:
domain.ownership_txt_token = random_string(30)
Session.commit()
return f"{prefix}-verification={domain.ownership_txt_token}"
def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]:
@ -164,9 +170,11 @@ class CustomDomainValidation:
else:
custom_domain.spf_verified = False
Session.commit()
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
cleaned_records = self.__clean_spf_records(txt_records, custom_domain)
return DomainValidationResult(
success=False,
errors=self._dns_client.get_txt_record(custom_domain.domain),
errors=cleaned_records,
)
def validate_dmarc_records(
@ -181,3 +189,13 @@ class CustomDomainValidation:
custom_domain.dmarc_verified = False
Session.commit()
return DomainValidationResult(success=False, errors=txt_records)
def __clean_spf_records(
self, txt_records: List[str], custom_domain: CustomDomain
) -> List[str]:
final_records = []
verification_record = self.get_ownership_verification_record(custom_domain)
for record in txt_records:
if record != verification_record:
final_records.append(record)
return final_records

View file

@ -21,7 +21,9 @@ class NewCustomDomainForm(FlaskForm):
@parallel_limiter.lock(only_when=lambda: request.method == "POST")
def custom_domain():
custom_domains = CustomDomain.filter_by(
user_id=current_user.id, is_sl_subdomain=False
user_id=current_user.id,
is_sl_subdomain=False,
pending_deletion=False,
).all()
new_custom_domain_form = NewCustomDomainForm()

View file

@ -106,7 +106,7 @@ class NetworkDNSClient(DNSClient):
def get_txt_record(self, hostname: str) -> List[str]:
try:
answers = self._resolver.resolve(hostname, "TXT", search=True)
answers = self._resolver.resolve(hostname, "TXT", search=False)
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:

View file

@ -470,6 +470,33 @@ def test_custom_domain_validation_validate_spf_records_partner_domain_success():
assert db_domain.spf_verified is True
def test_custom_domain_validation_validate_spf_cleans_verification_record():
dns_client = InMemoryDNSClient()
proton_partner_id = get_proton_partner().id
expected_domain = random_domain()
validator = CustomDomainValidation(
dkim_domain=random_domain(),
dns_client=dns_client,
partner_domains={proton_partner_id: expected_domain},
)
domain = create_custom_domain(random_domain())
domain.partner_id = proton_partner_id
Session.commit()
wrong_record = random_string()
dns_client.set_txt_record(
hostname=domain.domain,
txt_list=[wrong_record, validator.get_ownership_verification_record(domain)],
)
res = validator.validate_spf_records(domain)
assert res.success is False
assert len(res.errors) == 1
assert res.errors[0] == wrong_record
# validate_dmarc_records
def test_custom_domain_validation_validate_dmarc_records_empty_failure():
dns_client = InMemoryDNSClient()