mirror of
https://github.com/simple-login/app.git
synced 2024-11-15 05:07:33 +08:00
9d5697b624
* chore: DNS validation improvements * fix: do not show domains pending deletion * fix: generate verification token if null * revert: dmarc cleanup
545 lines
18 KiB
Python
545 lines
18 KiB
Python
from typing import Optional
|
|
|
|
from app import config
|
|
from app.constants import DMARC_RECORD
|
|
from app.custom_domain_validation import CustomDomainValidation
|
|
from app.db import Session
|
|
from app.models import CustomDomain, User
|
|
from app.dns_utils import InMemoryDNSClient, MxRecord
|
|
from app.proton.utils import get_proton_partner
|
|
from app.utils import random_string
|
|
from tests.utils import create_new_user, random_domain
|
|
|
|
# Shared test user for this module; starts as None and is assigned by setup_module().
user: Optional[User] = None
|
|
|
|
|
|
def setup_module():
    """Prepare the module-wide state shared by the tests below.

    Sets ``config.SKIP_MX_LOOKUP_ON_CHECK`` to True, creates a new user with
    ``trial_end`` cleared and ``lifetime`` set, stores it in the module-level
    ``user`` global and commits the session.
    """
    global user
    config.SKIP_MX_LOOKUP_ON_CHECK = True

    account = create_new_user()
    account.trial_end = None
    account.lifetime = True
    user = account

    Session.commit()
|
|
|
|
|
|
def create_custom_domain(domain: str) -> CustomDomain:
    """Create and commit a CustomDomain named *domain* owned by the module's ``user``."""
    owner_id = user.id
    custom_domain = CustomDomain.create(user_id=owner_id, domain=domain, commit=True)
    return custom_domain
|
|
|
|
|
|
def test_custom_domain_validation_get_dkim_records():
    """Three DKIM records are returned, each pointing at the validator's DKIM domain."""
    domain_name = random_domain()
    custom_domain = create_custom_domain(domain_name)

    dkim_records = CustomDomainValidation(domain_name).get_dkim_records(custom_domain)

    assert len(dkim_records) == 3
    # Every selector must map to "<selector>.<dkim domain>".
    for selector in ("dkim02._domainkey", "dkim03._domainkey", "dkim._domainkey"):
        assert dkim_records[selector] == f"{selector}.{domain_name}"
|
|
|
|
|
|
def test_custom_domain_validation_get_dkim_records_for_partner():
    """For a partner domain, DKIM records point at the domain configured for that partner."""
    domain_name = random_domain()
    custom_domain = create_custom_domain(domain_name)

    # Attach the custom domain to the Proton partner.
    proton_id = get_proton_partner().id
    custom_domain.partner_id = proton_id
    Session.commit()

    partner_dkim_domain = random_domain()
    partner_config = {proton_id: partner_dkim_domain}
    validator = CustomDomainValidation(domain_name, partner_domains=partner_config)
    dkim_records = validator.get_dkim_records(custom_domain)

    assert len(dkim_records) == 3
    # Every selector must map to "<selector>.<partner dkim domain>".
    for selector in ("dkim02._domainkey", "dkim03._domainkey", "dkim._domainkey"):
        assert dkim_records[selector] == f"{selector}.{partner_dkim_domain}"
|
|
|
|
|
|
# get_expected_mx_records
|
|
def test_custom_domain_validation_get_expected_mx_records_regular_domain():
    """A non-partner domain gets the default MX records despite a partner config."""
    domain_name = random_domain()
    custom_domain = create_custom_domain(domain_name)

    proton_id = get_proton_partner().id
    partner_config = {proton_id: random_domain()}
    validator = CustomDomainValidation(domain_name, partner_domains=partner_config)
    mx_records = validator.get_expected_mx_records(custom_domain)

    # The domain is not a partner domain, so the default records must be used
    # even though a config exists for the partner.
    default_servers = config.EMAIL_SERVERS_WITH_PRIORITY
    assert len(mx_records) == len(default_servers)
    for position, server in enumerate(default_servers):
        assert mx_records[position].priority == server[0]
        assert mx_records[position].domain == server[1]
|
|
|
|
|
|
def test_custom_domain_validation_get_expected_mx_records_domain_from_partner():
    """A partner domain without a partner config still gets the default MX records."""
    custom_domain = create_custom_domain(random_domain())

    custom_domain.partner_id = get_proton_partner().id
    Session.commit()

    # No partner_domains mapping is handed to the validator here.
    validator = CustomDomainValidation(random_domain())
    mx_records = validator.get_expected_mx_records(custom_domain)

    # The domain belongs to a partner, but there is no custom config for that
    # partner, so the default records must be used.
    default_servers = config.EMAIL_SERVERS_WITH_PRIORITY
    assert len(mx_records) == len(default_servers)
    for position, server in enumerate(default_servers):
        assert mx_records[position].priority == server[0]
        assert mx_records[position].domain == server[1]
|
|
|
|
|
|
def test_custom_domain_validation_get_expected_mx_records_domain_from_partner_with_custom_config():
    """A partner domain with a partner config gets mx1/mx2 records on the partner domain."""
    custom_domain = create_custom_domain(random_domain())

    proton_id = get_proton_partner().id
    custom_domain.partner_id = proton_id
    Session.commit()

    dkim_domain = random_domain()
    partner_mx_domain = random_domain()
    validator = CustomDomainValidation(
        dkim_domain, partner_domains={proton_id: partner_mx_domain}
    )
    mx_records = validator.get_expected_mx_records(custom_domain)

    # The domain belongs to a partner and a custom config exists for that
    # partner, so the partner records must be used.
    expected = [
        (10, f"mx1.{partner_mx_domain}."),
        (20, f"mx2.{partner_mx_domain}."),
    ]
    assert len(mx_records) == 2
    for position, (priority, host) in enumerate(expected):
        assert mx_records[position].priority == priority
        assert mx_records[position].domain == host
|
|
|
|
|
|
# get_expected_spf_record
|
|
def test_custom_domain_validation_get_expected_spf_record_regular_domain():
    """A non-partner domain gets the default SPF record despite a partner config."""
    domain_name = random_domain()
    custom_domain = create_custom_domain(domain_name)

    proton_id = get_proton_partner().id
    partner_config = {proton_id: random_domain()}
    validator = CustomDomainValidation(domain_name, partner_domains=partner_config)
    spf_record = validator.get_expected_spf_record(custom_domain)

    # The domain is not a partner domain, so the default record must be used
    # even though a config exists for the partner.
    expected = f"v=spf1 include:{config.EMAIL_DOMAIN} ~all"
    assert spf_record == expected
|
|
|
|
|
|
def test_custom_domain_validation_get_expected_spf_record_domain_from_partner():
    """A partner domain without a partner config still gets the default SPF record."""
    custom_domain = create_custom_domain(random_domain())

    custom_domain.partner_id = get_proton_partner().id
    Session.commit()

    # No partner_domains mapping is handed to the validator here.
    validator = CustomDomainValidation(random_domain())
    spf_record = validator.get_expected_spf_record(custom_domain)

    # The domain belongs to a partner, but there is no custom config for that
    # partner, so the default record must be used.
    expected = f"v=spf1 include:{config.EMAIL_DOMAIN} ~all"
    assert spf_record == expected
|
|
|
|
|
|
def test_custom_domain_validation_get_expected_spf_record_domain_from_partner_with_custom_config():
    """A partner domain with a partner config gets an SPF record including the partner domain."""
    custom_domain = create_custom_domain(random_domain())

    proton_id = get_proton_partner().id
    custom_domain.partner_id = proton_id
    Session.commit()

    dkim_domain = random_domain()
    partner_spf_domain = random_domain()
    validator = CustomDomainValidation(
        dkim_domain, partner_domains={proton_id: partner_spf_domain}
    )
    spf_record = validator.get_expected_spf_record(custom_domain)

    # The domain belongs to a partner and a custom config exists for that
    # partner, so the partner record must be used.
    expected = f"v=spf1 include:{partner_spf_domain} ~all"
    assert spf_record == expected
|
|
|
|
|
|
# validate_dkim_records
|
|
def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
    """With no DNS records set, all three DKIM entries come back as "empty"."""
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns)
    custom_domain = create_custom_domain(random_domain())

    invalid_records = validator.validate_dkim_records(custom_domain)

    assert len(invalid_records) == 3
    assert all(value == "empty" for value in invalid_records.values())

    # The flag read back from the database must be False.
    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.dkim_verified is False
|
|
|
|
|
|
def test_custom_domain_validation_validate_dkim_records_wrong_records_failure():
    """Two wrong CNAMEs out of three are reported and DKIM is not verified."""
    dkim_domain = random_domain()
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(dkim_domain, dns)

    user_domain = random_domain()

    # One record is right, the other two point at a wrong target.
    cname_targets = {
        "dkim": f"dkim._domainkey.{dkim_domain}",
        "dkim02": "wrong",
        "dkim03": "wrong",
    }
    for selector, target in cname_targets.items():
        dns.set_cname_record(f"{selector}._domainkey.{user_domain}", target)

    custom_domain = create_custom_domain(user_domain)
    invalid_records = validator.validate_dkim_records(custom_domain)

    assert len(invalid_records) == 2
    assert all(value == "wrong" for value in invalid_records.values())

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.dkim_verified is False
|
|
|
|
|
|
def test_custom_domain_validation_validate_dkim_records_success_with_old_system():
    """An already-verified domain with only the ``dkim`` CNAME keeps its verified flag."""
    dkim_domain = random_domain()
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(dkim_domain, dns)

    user_domain = random_domain()

    # Only the "dkim" record is right; the other two are missing.
    dns.set_cname_record(
        f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
    )

    custom_domain = create_custom_domain(user_domain)

    # Mark DKIM as verified before running the validation.
    custom_domain.dkim_verified = True
    Session.commit()

    invalid_records = validator.validate_dkim_records(custom_domain)
    assert len(invalid_records) == 2
    for selector in ("dkim02", "dkim03"):
        assert f"{selector}._domainkey.{user_domain}" in invalid_records

    # The flag is not cleared.
    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.dkim_verified is True
|
|
|
|
|
|
def test_custom_domain_validation_validate_dkim_records_success():
    """With all three CNAMEs right, nothing is reported and DKIM becomes verified."""
    dkim_domain = random_domain()
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(dkim_domain, dns)

    user_domain = random_domain()

    # All three records point at the matching record on the DKIM domain.
    for selector in ("dkim", "dkim02", "dkim03"):
        dns.set_cname_record(
            f"{selector}._domainkey.{user_domain}",
            f"{selector}._domainkey.{dkim_domain}",
        )

    custom_domain = create_custom_domain(user_domain)
    invalid_records = validator.validate_dkim_records(custom_domain)
    assert len(invalid_records) == 0

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.dkim_verified is True
|
|
|
|
|
|
# validate_domain_ownership
|
|
def test_custom_domain_validation_validate_ownership_empty_records_failure():
    """With no TXT records set, ownership validation fails and reports no errors."""
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns)
    custom_domain = create_custom_domain(random_domain())

    result = validator.validate_domain_ownership(custom_domain)

    assert result.success is False
    assert len(result.errors) == 0

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.ownership_verified is False
|
|
|
|
|
|
def test_custom_domain_validation_validate_ownership_wrong_records_failure():
    """A non-matching TXT record fails ownership validation and is returned as an error."""
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns)
    custom_domain = create_custom_domain(random_domain())

    # Publish a single random TXT value on the domain.
    bogus_txt = [random_string()]
    dns.set_txt_record(custom_domain.domain, bogus_txt)

    result = validator.validate_domain_ownership(custom_domain)

    assert result.success is False
    assert result.errors == bogus_txt

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.ownership_verified is False
|
|
|
|
|
|
def test_custom_domain_validation_validate_ownership_success():
    """Publishing the verification record as TXT makes ownership validation succeed."""
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns)
    custom_domain = create_custom_domain(random_domain())

    ownership_record = validator.get_ownership_verification_record(custom_domain)
    dns.set_txt_record(custom_domain.domain, [ownership_record])

    result = validator.validate_domain_ownership(custom_domain)

    assert result.success is True
    assert len(result.errors) == 0

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.ownership_verified is True
|
|
|
|
|
|
def test_custom_domain_validation_validate_ownership_from_partner_success():
    """Ownership validation succeeds for a partner domain with a partner validation prefix."""
    dns = InMemoryDNSClient()
    proton_id = get_proton_partner().id

    validation_prefix = random_string()
    validator = CustomDomainValidation(
        random_domain(),
        dns,
        partner_domains_validation_prefixes={proton_id: validation_prefix},
    )

    # Attach the custom domain to the Proton partner.
    custom_domain = create_custom_domain(random_domain())
    custom_domain.partner_id = proton_id
    Session.commit()

    ownership_record = validator.get_ownership_verification_record(custom_domain)
    dns.set_txt_record(custom_domain.domain, [ownership_record])

    result = validator.validate_domain_ownership(custom_domain)

    assert result.success is True
    assert len(result.errors) == 0

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.ownership_verified is True
|
|
|
|
|
|
# validate_mx_records
|
|
def test_custom_domain_validation_validate_mx_records_empty_failure():
    """With no MX records set, MX validation fails and reports no errors."""
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns)
    custom_domain = create_custom_domain(random_domain())

    result = validator.validate_mx_records(custom_domain)

    assert result.success is False
    assert len(result.errors) == 0

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.verified is False
|
|
|
|
|
|
def test_custom_domain_validation_validate_mx_records_wrong_records_failure():
    """Non-matching MX records fail validation and come back as "<priority> <host>" errors."""
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns)
    custom_domain = create_custom_domain(random_domain())

    # Publish two MX records with random hosts.
    first_host = random_string()
    second_host = random_string()
    bogus_mx = [MxRecord(10, first_host), MxRecord(20, second_host)]
    dns.set_mx_records(custom_domain.domain, bogus_mx)

    result = validator.validate_mx_records(custom_domain)

    assert result.success is False
    expected_errors = [f"10 {first_host}", f"20 {second_host}"]
    assert result.errors == expected_errors

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.verified is False
|
|
|
|
|
|
def test_custom_domain_validation_validate_mx_records_success():
    """Publishing the expected MX records makes MX validation succeed."""
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns)
    custom_domain = create_custom_domain(random_domain())

    expected_mx = validator.get_expected_mx_records(custom_domain)
    dns.set_mx_records(custom_domain.domain, expected_mx)

    result = validator.validate_mx_records(custom_domain)

    assert result.success is True
    assert len(result.errors) == 0

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.verified is True
|
|
|
|
|
|
# validate_spf_records
|
|
def test_custom_domain_validation_validate_spf_records_empty_failure():
    """With no TXT records set, SPF validation fails and reports no errors."""
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns)
    custom_domain = create_custom_domain(random_domain())

    result = validator.validate_spf_records(custom_domain)

    assert result.success is False
    assert len(result.errors) == 0

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.spf_verified is False
|
|
|
|
|
|
def test_custom_domain_validation_validate_spf_records_wrong_records_failure():
    """A non-SPF TXT record fails SPF validation and is returned as an error."""
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns)
    custom_domain = create_custom_domain(random_domain())

    # Publish a single random TXT value on the domain.
    bogus_txt = [random_string()]
    dns.set_txt_record(custom_domain.domain, bogus_txt)

    result = validator.validate_spf_records(custom_domain)

    assert result.success is False
    assert result.errors == bogus_txt

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.spf_verified is False
|
|
|
|
|
|
def test_custom_domain_validation_validate_spf_records_success():
    """A TXT record including ``config.EMAIL_DOMAIN`` makes SPF validation succeed."""
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns)
    custom_domain = create_custom_domain(random_domain())

    spf_txt = f"v=spf1 include:{config.EMAIL_DOMAIN}"
    dns.set_txt_record(custom_domain.domain, [spf_txt])

    result = validator.validate_spf_records(custom_domain)

    assert result.success is True
    assert len(result.errors) == 0

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.spf_verified is True
|
|
|
|
|
|
def test_custom_domain_validation_validate_spf_records_partner_domain_success():
    """For a partner domain, an SPF record including the partner's domain validates."""
    dns = InMemoryDNSClient()
    proton_id = get_proton_partner().id

    partner_spf_domain = random_domain()
    validator = CustomDomainValidation(
        dkim_domain=random_domain(),
        dns_client=dns,
        partner_domains={proton_id: partner_spf_domain},
    )

    # Attach the custom domain to the Proton partner.
    custom_domain = create_custom_domain(random_domain())
    custom_domain.partner_id = proton_id
    Session.commit()

    spf_txt = f"v=spf1 include:{partner_spf_domain}"
    dns.set_txt_record(custom_domain.domain, [spf_txt])

    result = validator.validate_spf_records(custom_domain)

    assert result.success is True
    assert len(result.errors) == 0

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.spf_verified is True
|
|
|
|
|
|
def test_custom_domain_validation_validate_spf_cleans_verification_record():
    """SPF errors list the wrong TXT record but not the ownership verification record."""
    dns = InMemoryDNSClient()
    proton_id = get_proton_partner().id

    partner_spf_domain = random_domain()
    validator = CustomDomainValidation(
        dkim_domain=random_domain(),
        dns_client=dns,
        partner_domains={proton_id: partner_spf_domain},
    )

    # Attach the custom domain to the Proton partner.
    custom_domain = create_custom_domain(random_domain())
    custom_domain.partner_id = proton_id
    Session.commit()

    # Publish a random TXT value next to the ownership verification record.
    bogus_txt = random_string()
    ownership_record = validator.get_ownership_verification_record(custom_domain)
    dns.set_txt_record(
        hostname=custom_domain.domain,
        txt_list=[bogus_txt, ownership_record],
    )

    result = validator.validate_spf_records(custom_domain)

    assert result.success is False
    # Only the random value is reported back.
    assert len(result.errors) == 1
    assert result.errors[0] == bogus_txt
|
|
|
|
|
|
# validate_dmarc_records
|
|
def test_custom_domain_validation_validate_dmarc_records_empty_failure():
    """With no TXT records set, DMARC validation fails and reports no errors."""
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns)
    custom_domain = create_custom_domain(random_domain())

    result = validator.validate_dmarc_records(custom_domain)

    assert result.success is False
    assert len(result.errors) == 0

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.dmarc_verified is False
|
|
|
|
|
|
def test_custom_domain_validation_validate_dmarc_records_wrong_records_failure():
    """A non-matching ``_dmarc`` TXT record fails validation and is returned as an error."""
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns)
    custom_domain = create_custom_domain(random_domain())

    # Publish a single random TXT value on the _dmarc subdomain.
    bogus_txt = [random_string()]
    dmarc_host = f"_dmarc.{custom_domain.domain}"
    dns.set_txt_record(dmarc_host, bogus_txt)

    result = validator.validate_dmarc_records(custom_domain)

    assert result.success is False
    assert result.errors == bogus_txt

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.dmarc_verified is False
|
|
|
|
|
|
def test_custom_domain_validation_validate_dmarc_records_success():
    """Publishing ``DMARC_RECORD`` on the ``_dmarc`` subdomain makes validation succeed."""
    dns = InMemoryDNSClient()
    validator = CustomDomainValidation(random_domain(), dns)
    custom_domain = create_custom_domain(random_domain())

    dmarc_host = f"_dmarc.{custom_domain.domain}"
    dns.set_txt_record(dmarc_host, [DMARC_RECORD])

    result = validator.validate_dmarc_records(custom_domain)

    assert result.success is True
    assert len(result.errors) == 0

    stored = CustomDomain.get_by(id=custom_domain.id)
    assert stored.dmarc_verified is True
|