app/tests/test_custom_domain_validation.py
Carlos Quintana 9d5697b624
chore: DNS validation improvements (#2248)
* chore: DNS validation improvements

* fix: do not show domains pending deletion

* fix: generate verification token if null

* revert: dmarc cleanup
2024-10-03 13:04:17 +02:00

545 lines
18 KiB
Python

from typing import Optional
from app import config
from app.constants import DMARC_RECORD
from app.custom_domain_validation import CustomDomainValidation
from app.db import Session
from app.models import CustomDomain, User
from app.dns_utils import InMemoryDNSClient, MxRecord
from app.proton.utils import get_proton_partner
from app.utils import random_string
from tests.utils import create_new_user, random_domain
# Account shared by the tests in this module; assigned in setup_module().
user: Optional[User] = None


def setup_module():
    """Module-level setup: enable the MX-lookup skip flag and create the shared user.

    Sets ``config.SKIP_MX_LOOKUP_ON_CHECK`` to ``True`` and stores a freshly
    created user (no trial end date, lifetime flag set) in the module-level
    ``user`` variable, then commits the session.

    NOTE(review): the config flag is not restored anywhere in the visible
    part of this module -- confirm other test modules do not depend on it.
    """
    global user
    config.SKIP_MX_LOOKUP_ON_CHECK = True
    account = create_new_user()
    account.trial_end = None
    account.lifetime = True
    user = account
    Session.commit()
def create_custom_domain(domain: str) -> CustomDomain:
    """Create and commit a ``CustomDomain`` named *domain* for the shared ``user``."""
    owner_id = user.id
    return CustomDomain.create(user_id=owner_id, domain=domain, commit=True)
def test_custom_domain_validation_get_dkim_records():
    """A domain without a partner gets three DKIM records under the validator's domain."""
    name = random_domain()
    custom_domain = create_custom_domain(name)
    records = CustomDomainValidation(name).get_dkim_records(custom_domain)

    assert len(records) == 3
    # Every selector must map to the same selector under the validator's domain.
    for selector in ("dkim02._domainkey", "dkim03._domainkey", "dkim._domainkey"):
        assert records[selector] == f"{selector}.{name}"
def test_custom_domain_validation_get_dkim_records_for_partner():
    """A partner domain gets DKIM records under the domain configured for that partner."""
    name = random_domain()
    custom_domain = create_custom_domain(name)
    proton_id = get_proton_partner().id
    custom_domain.partner_id = proton_id
    Session.commit()

    partner_dkim_domain = random_domain()
    validator = CustomDomainValidation(
        name, partner_domains={proton_id: partner_dkim_domain}
    )
    records = validator.get_dkim_records(custom_domain)

    assert len(records) == 3
    # Targets use the partner's domain, not the one given as first argument.
    for selector in ("dkim02._domainkey", "dkim03._domainkey", "dkim._domainkey"):
        assert records[selector] == f"{selector}.{partner_dkim_domain}"
# get_expected_mx_records
def test_custom_domain_validation_get_expected_mx_records_regular_domain():
    """A domain without a partner expects the default MX records.

    The validator is given a partner configuration, but the domain under
    test has no ``partner_id`` set, so the expected records must match
    ``config.EMAIL_SERVERS_WITH_PRIORITY`` entry by entry.
    """
    domain = random_domain()
    custom_domain = create_custom_domain(domain)
    partner_id = get_proton_partner().id
    partner_domain = random_domain()
    validator = CustomDomainValidation(
        domain, partner_domains={partner_id: partner_domain}
    )
    records = validator.get_expected_mx_records(custom_domain)

    # As the domain is not a partner domain, default records should be used
    # even if there is a config for the partner.
    expected = config.EMAIL_SERVERS_WITH_PRIORITY
    assert len(records) == len(expected)
    # Lengths are equal at this point, so zip() pairs every record.
    for record, config_record in zip(records, expected):
        assert record.priority == config_record[0]
        assert record.domain == config_record[1]
def test_custom_domain_validation_get_expected_mx_records_domain_from_partner():
    """A partner domain without a partner config expects the default MX records.

    The domain has ``partner_id`` set, but the validator is built without
    ``partner_domains``, so the expected records must match
    ``config.EMAIL_SERVERS_WITH_PRIORITY`` entry by entry.
    """
    domain = random_domain()
    custom_domain = create_custom_domain(domain)
    partner_id = get_proton_partner().id
    custom_domain.partner_id = partner_id
    Session.commit()
    dkim_domain = random_domain()
    validator = CustomDomainValidation(dkim_domain)
    records = validator.get_expected_mx_records(custom_domain)

    # As the domain is a partner domain but there is no custom config for the
    # partner, default records should be used.
    expected = config.EMAIL_SERVERS_WITH_PRIORITY
    assert len(records) == len(expected)
    # Lengths are equal at this point, so zip() pairs every record.
    for record, config_record in zip(records, expected):
        assert record.priority == config_record[0]
        assert record.domain == config_record[1]
def test_custom_domain_validation_get_expected_mx_records_domain_from_partner_with_custom_config():
    """A partner domain with a partner config expects that partner's MX hosts."""
    name = random_domain()
    custom_domain = create_custom_domain(name)
    proton_id = get_proton_partner().id
    custom_domain.partner_id = proton_id
    Session.commit()

    dkim_domain = random_domain()
    partner_mx_domain = random_domain()
    validator = CustomDomainValidation(
        dkim_domain, partner_domains={proton_id: partner_mx_domain}
    )
    records = validator.get_expected_mx_records(custom_domain)

    # The partner's records are used: mx1 at priority 10, then mx2 at priority 20.
    assert len(records) == 2
    for index, (priority, host) in enumerate(((10, "mx1"), (20, "mx2"))):
        assert records[index].priority == priority
        assert records[index].domain == f"{host}.{partner_mx_domain}."
# get_expected_spf_record
def test_custom_domain_validation_get_expected_spf_record_regular_domain():
    """A domain without a partner expects the default SPF record.

    The validator carries a partner configuration, but the domain under test
    has no ``partner_id`` set, so ``config.EMAIL_DOMAIN`` must be included.
    """
    name = random_domain()
    custom_domain = create_custom_domain(name)
    proton_id = get_proton_partner().id
    partner_domain = random_domain()
    validator = CustomDomainValidation(
        name, partner_domains={proton_id: partner_domain}
    )

    expected = f"v=spf1 include:{config.EMAIL_DOMAIN} ~all"
    assert validator.get_expected_spf_record(custom_domain) == expected
def test_custom_domain_validation_get_expected_spf_record_domain_from_partner():
    """A partner domain without a partner config expects the default SPF record."""
    name = random_domain()
    custom_domain = create_custom_domain(name)
    custom_domain.partner_id = get_proton_partner().id
    Session.commit()

    # No partner_domains are passed, so the default include must be used.
    validator = CustomDomainValidation(random_domain())

    expected = f"v=spf1 include:{config.EMAIL_DOMAIN} ~all"
    assert validator.get_expected_spf_record(custom_domain) == expected
def test_custom_domain_validation_get_expected_spf_record_domain_from_partner_with_custom_config():
    """A partner domain with a partner config expects an SPF include of the partner domain."""
    name = random_domain()
    custom_domain = create_custom_domain(name)
    proton_id = get_proton_partner().id
    custom_domain.partner_id = proton_id
    Session.commit()

    dkim_domain = random_domain()
    partner_spf_domain = random_domain()
    validator = CustomDomainValidation(
        dkim_domain, partner_domains={proton_id: partner_spf_domain}
    )

    # The configured partner domain replaces config.EMAIL_DOMAIN in the record.
    expected = f"v=spf1 include:{partner_spf_domain} ~all"
    assert validator.get_expected_spf_record(custom_domain) == expected
# validate_dkim_records
def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
    """With no CNAME records set, all three DKIM records are reported as "empty"."""
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns_client=client)
    custom_domain = create_custom_domain(random_domain())

    invalid = validator.validate_dkim_records(custom_domain)

    assert len(invalid) == 3
    assert all(value == "empty" for value in invalid.values())
    # Reload from the database to check the persisted flag.
    assert CustomDomain.get_by(id=custom_domain.id).dkim_verified is False
def test_custom_domain_validation_validate_dkim_records_wrong_records_failure():
    """Two DKIM CNAMEs with the wrong target are reported; DKIM is not verified."""
    dkim_domain = random_domain()
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(dkim_domain, dns_client=client)
    user_domain = random_domain()

    # "dkim" points at the right target; "dkim02" and "dkim03" do not.
    client.set_cname_record(
        f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
    )
    for selector in ("dkim02", "dkim03"):
        client.set_cname_record(f"{selector}._domainkey.{user_domain}", "wrong")

    custom_domain = create_custom_domain(user_domain)
    invalid = validator.validate_dkim_records(custom_domain)

    assert len(invalid) == 2
    assert all(value == "wrong" for value in invalid.values())
    assert CustomDomain.get_by(id=custom_domain.id).dkim_verified is False
def test_custom_domain_validation_validate_dkim_records_success_with_old_system():
    """An already-verified domain with only the "dkim" CNAME keeps its verified flag.

    The two missing records are still returned by the validation call.
    """
    dkim_domain = random_domain()
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(dkim_domain, dns_client=client)
    user_domain = random_domain()

    # Only the "dkim" selector is configured; dkim02 and dkim03 are absent.
    client.set_cname_record(
        f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
    )

    custom_domain = create_custom_domain(user_domain)
    # Mark DKIM as verified before running the validation.
    custom_domain.dkim_verified = True
    Session.commit()

    invalid = validator.validate_dkim_records(custom_domain)

    assert len(invalid) == 2
    for selector in ("dkim02", "dkim03"):
        assert f"{selector}._domainkey.{user_domain}" in invalid

    # The flag must not have been cleared.
    assert CustomDomain.get_by(id=custom_domain.id).dkim_verified is True
def test_custom_domain_validation_validate_dkim_records_success():
    """With all three DKIM CNAMEs set correctly, nothing is reported and DKIM is verified."""
    dkim_domain = random_domain()
    dns_client = InMemoryDNSClient()
    validator = CustomDomainValidation(dkim_domain, dns_client)
    user_domain = random_domain()

    # All three selectors point at the matching record under the DKIM domain.
    for selector in ("dkim", "dkim02", "dkim03"):
        dns_client.set_cname_record(
            f"{selector}._domainkey.{user_domain}",
            f"{selector}._domainkey.{dkim_domain}",
        )

    domain = create_custom_domain(user_domain)
    res = validator.validate_dkim_records(domain)
    assert len(res) == 0

    # Reload from the database to check the persisted flag.
    db_domain = CustomDomain.get_by(id=domain.id)
    assert db_domain.dkim_verified is True
# validate_domain_ownership
def test_custom_domain_validation_validate_ownership_empty_records_failure():
    """With no TXT records set, ownership validation fails and reports no errors."""
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns_client=client)
    custom_domain = create_custom_domain(random_domain())

    result = validator.validate_domain_ownership(custom_domain)

    assert result.success is False
    assert len(result.errors) == 0
    assert CustomDomain.get_by(id=custom_domain.id).ownership_verified is False
def test_custom_domain_validation_validate_ownership_wrong_records_failure():
    """An unrelated TXT record fails ownership validation and is returned as the error."""
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns_client=client)
    custom_domain = create_custom_domain(random_domain())

    unrelated_txt = [random_string()]
    client.set_txt_record(custom_domain.domain, unrelated_txt)

    result = validator.validate_domain_ownership(custom_domain)

    assert result.success is False
    assert result.errors == unrelated_txt
    assert CustomDomain.get_by(id=custom_domain.id).ownership_verified is False
def test_custom_domain_validation_validate_ownership_success():
    """Publishing the expected verification TXT record makes ownership validation pass."""
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns_client=client)
    custom_domain = create_custom_domain(random_domain())

    # Publish exactly the record the validator says it expects.
    verification_record = validator.get_ownership_verification_record(custom_domain)
    client.set_txt_record(custom_domain.domain, [verification_record])

    result = validator.validate_domain_ownership(custom_domain)

    assert result.success is True
    assert len(result.errors) == 0
    assert CustomDomain.get_by(id=custom_domain.id).ownership_verified is True
def test_custom_domain_validation_validate_ownership_from_partner_success():
    """Ownership validation passes for a partner domain with a partner validation prefix."""
    client = InMemoryDNSClient()
    proton_id = get_proton_partner().id
    partner_prefix = random_string()
    validator = CustomDomainValidation(
        random_domain(),
        dns_client=client,
        partner_domains_validation_prefixes={proton_id: partner_prefix},
    )

    custom_domain = create_custom_domain(random_domain())
    custom_domain.partner_id = proton_id
    Session.commit()

    # Publish exactly the record the validator says it expects for this domain.
    verification_record = validator.get_ownership_verification_record(custom_domain)
    client.set_txt_record(custom_domain.domain, [verification_record])

    result = validator.validate_domain_ownership(custom_domain)

    assert result.success is True
    assert len(result.errors) == 0
    assert CustomDomain.get_by(id=custom_domain.id).ownership_verified is True
# validate_mx_records
def test_custom_domain_validation_validate_mx_records_empty_failure():
    """With no MX records set, MX validation fails and reports no errors."""
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns_client=client)
    custom_domain = create_custom_domain(random_domain())

    result = validator.validate_mx_records(custom_domain)

    assert result.success is False
    assert len(result.errors) == 0
    assert CustomDomain.get_by(id=custom_domain.id).verified is False
def test_custom_domain_validation_validate_mx_records_wrong_records_failure():
    """Unexpected MX records fail validation and are returned as "<priority> <host>"."""
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns_client=client)
    custom_domain = create_custom_domain(random_domain())

    first_host = random_string()
    second_host = random_string()
    client.set_mx_records(
        custom_domain.domain,
        [MxRecord(10, first_host), MxRecord(20, second_host)],
    )

    result = validator.validate_mx_records(custom_domain)

    assert result.success is False
    assert result.errors == [f"10 {first_host}", f"20 {second_host}"]
    assert CustomDomain.get_by(id=custom_domain.id).verified is False
def test_custom_domain_validation_validate_mx_records_success():
    """Publishing the expected MX records makes MX validation pass and verifies the domain."""
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns_client=client)
    custom_domain = create_custom_domain(random_domain())

    # Publish exactly the records the validator says it expects.
    expected_records = validator.get_expected_mx_records(custom_domain)
    client.set_mx_records(custom_domain.domain, expected_records)

    result = validator.validate_mx_records(custom_domain)

    assert result.success is True
    assert len(result.errors) == 0
    assert CustomDomain.get_by(id=custom_domain.id).verified is True
# validate_spf_records
def test_custom_domain_validation_validate_spf_records_empty_failure():
    """With no TXT records set, SPF validation fails and reports no errors."""
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns_client=client)
    custom_domain = create_custom_domain(random_domain())

    result = validator.validate_spf_records(custom_domain)

    assert result.success is False
    assert len(result.errors) == 0
    assert CustomDomain.get_by(id=custom_domain.id).spf_verified is False
def test_custom_domain_validation_validate_spf_records_wrong_records_failure():
    """An unrelated TXT record fails SPF validation and is returned as the error."""
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns_client=client)
    custom_domain = create_custom_domain(random_domain())

    unrelated_txt = [random_string()]
    client.set_txt_record(custom_domain.domain, unrelated_txt)

    result = validator.validate_spf_records(custom_domain)

    assert result.success is False
    assert result.errors == unrelated_txt
    assert CustomDomain.get_by(id=custom_domain.id).spf_verified is False
def test_custom_domain_validation_validate_spf_records_success():
    """A TXT record including ``config.EMAIL_DOMAIN`` makes SPF validation pass."""
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns_client=client)
    custom_domain = create_custom_domain(random_domain())

    spf_txt = f"v=spf1 include:{config.EMAIL_DOMAIN}"
    client.set_txt_record(custom_domain.domain, [spf_txt])

    result = validator.validate_spf_records(custom_domain)

    assert result.success is True
    assert len(result.errors) == 0
    assert CustomDomain.get_by(id=custom_domain.id).spf_verified is True
def test_custom_domain_validation_validate_spf_records_partner_domain_success():
    """For a partner domain, a TXT record including the partner's domain passes SPF."""
    client = InMemoryDNSClient()
    proton_id = get_proton_partner().id
    partner_spf_domain = random_domain()
    validator = CustomDomainValidation(
        dkim_domain=random_domain(),
        dns_client=client,
        partner_domains={proton_id: partner_spf_domain},
    )

    custom_domain = create_custom_domain(random_domain())
    custom_domain.partner_id = proton_id
    Session.commit()

    # The include points at the partner domain rather than config.EMAIL_DOMAIN.
    spf_txt = f"v=spf1 include:{partner_spf_domain}"
    client.set_txt_record(custom_domain.domain, [spf_txt])

    result = validator.validate_spf_records(custom_domain)

    assert result.success is True
    assert len(result.errors) == 0
    assert CustomDomain.get_by(id=custom_domain.id).spf_verified is True
def test_custom_domain_validation_validate_spf_cleans_verification_record():
    """The ownership verification TXT record is left out of the reported SPF errors."""
    client = InMemoryDNSClient()
    proton_id = get_proton_partner().id
    partner_spf_domain = random_domain()
    validator = CustomDomainValidation(
        dkim_domain=random_domain(),
        dns_client=client,
        partner_domains={proton_id: partner_spf_domain},
    )

    custom_domain = create_custom_domain(random_domain())
    custom_domain.partner_id = proton_id
    Session.commit()

    # Publish one unrelated TXT record next to the ownership verification record.
    unrelated_txt = random_string()
    verification_record = validator.get_ownership_verification_record(custom_domain)
    client.set_txt_record(
        hostname=custom_domain.domain,
        txt_list=[unrelated_txt, verification_record],
    )

    result = validator.validate_spf_records(custom_domain)

    # Only the unrelated record is reported back.
    assert result.success is False
    assert len(result.errors) == 1
    assert result.errors[0] == unrelated_txt
# validate_dmarc_records
def test_custom_domain_validation_validate_dmarc_records_empty_failure():
    """With no TXT records set, DMARC validation fails and reports no errors."""
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns_client=client)
    custom_domain = create_custom_domain(random_domain())

    result = validator.validate_dmarc_records(custom_domain)

    assert result.success is False
    assert len(result.errors) == 0
    assert CustomDomain.get_by(id=custom_domain.id).dmarc_verified is False
def test_custom_domain_validation_validate_dmarc_records_wrong_records_failure():
    """An unrelated ``_dmarc`` TXT record fails validation and is returned as the error."""
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns_client=client)
    custom_domain = create_custom_domain(random_domain())

    unrelated_txt = [random_string()]
    dmarc_host = f"_dmarc.{custom_domain.domain}"
    client.set_txt_record(dmarc_host, unrelated_txt)

    result = validator.validate_dmarc_records(custom_domain)

    assert result.success is False
    assert result.errors == unrelated_txt
    assert CustomDomain.get_by(id=custom_domain.id).dmarc_verified is False
def test_custom_domain_validation_validate_dmarc_records_success():
    """Publishing ``DMARC_RECORD`` on the ``_dmarc`` host makes DMARC validation pass."""
    client = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns_client=client)
    custom_domain = create_custom_domain(random_domain())

    dmarc_host = f"_dmarc.{custom_domain.domain}"
    client.set_txt_record(dmarc_host, [DMARC_RECORD])

    result = validator.validate_dmarc_records(custom_domain)

    assert result.success is True
    assert len(result.errors) == 0
    assert CustomDomain.get_by(id=custom_domain.id).dmarc_verified is True