mirror of
https://github.com/simple-login/app.git
synced 2024-11-15 05:07:33 +08:00
chore: DNS validation improvements (#2248)
* chore: DNS validation improvements * fix: do not show domains pending deletion * fix: generate verification token if null * revert: dmarc cleanup
This commit is contained in:
parent
06ab116476
commit
9d5697b624
4 changed files with 51 additions and 4 deletions
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from app import config
|
||||
from app.constants import DMARC_RECORD
|
||||
|
@ -11,6 +11,7 @@ from app.dns_utils import (
|
|||
get_network_dns_client,
|
||||
)
|
||||
from app.models import CustomDomain
|
||||
from app.utils import random_string
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -42,6 +43,11 @@ class CustomDomainValidation:
|
|||
and domain.partner_id in self._partner_domain_validation_prefixes
|
||||
):
|
||||
prefix = self._partner_domain_validation_prefixes[domain.partner_id]
|
||||
|
||||
if not domain.ownership_txt_token:
|
||||
domain.ownership_txt_token = random_string(30)
|
||||
Session.commit()
|
||||
|
||||
return f"{prefix}-verification={domain.ownership_txt_token}"
|
||||
|
||||
def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]:
|
||||
|
@ -164,9 +170,11 @@ class CustomDomainValidation:
|
|||
else:
|
||||
custom_domain.spf_verified = False
|
||||
Session.commit()
|
||||
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
|
||||
cleaned_records = self.__clean_spf_records(txt_records, custom_domain)
|
||||
return DomainValidationResult(
|
||||
success=False,
|
||||
errors=self._dns_client.get_txt_record(custom_domain.domain),
|
||||
errors=cleaned_records,
|
||||
)
|
||||
|
||||
def validate_dmarc_records(
|
||||
|
@ -181,3 +189,13 @@ class CustomDomainValidation:
|
|||
custom_domain.dmarc_verified = False
|
||||
Session.commit()
|
||||
return DomainValidationResult(success=False, errors=txt_records)
|
||||
|
||||
def __clean_spf_records(
|
||||
self, txt_records: List[str], custom_domain: CustomDomain
|
||||
) -> List[str]:
|
||||
final_records = []
|
||||
verification_record = self.get_ownership_verification_record(custom_domain)
|
||||
for record in txt_records:
|
||||
if record != verification_record:
|
||||
final_records.append(record)
|
||||
return final_records
|
||||
|
|
|
@ -21,7 +21,9 @@ class NewCustomDomainForm(FlaskForm):
|
|||
@parallel_limiter.lock(only_when=lambda: request.method == "POST")
|
||||
def custom_domain():
|
||||
custom_domains = CustomDomain.filter_by(
|
||||
user_id=current_user.id, is_sl_subdomain=False
|
||||
user_id=current_user.id,
|
||||
is_sl_subdomain=False,
|
||||
pending_deletion=False,
|
||||
).all()
|
||||
new_custom_domain_form = NewCustomDomainForm()
|
||||
|
||||
|
|
|
@ -106,7 +106,7 @@ class NetworkDNSClient(DNSClient):
|
|||
|
||||
def get_txt_record(self, hostname: str) -> List[str]:
|
||||
try:
|
||||
answers = self._resolver.resolve(hostname, "TXT", search=True)
|
||||
answers = self._resolver.resolve(hostname, "TXT", search=False)
|
||||
ret = []
|
||||
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
|
||||
for record in a.strings:
|
||||
|
|
|
@ -470,6 +470,33 @@ def test_custom_domain_validation_validate_spf_records_partner_domain_success():
|
|||
assert db_domain.spf_verified is True
|
||||
|
||||
|
||||
def test_custom_domain_validation_validate_spf_cleans_verification_record():
|
||||
dns_client = InMemoryDNSClient()
|
||||
proton_partner_id = get_proton_partner().id
|
||||
|
||||
expected_domain = random_domain()
|
||||
validator = CustomDomainValidation(
|
||||
dkim_domain=random_domain(),
|
||||
dns_client=dns_client,
|
||||
partner_domains={proton_partner_id: expected_domain},
|
||||
)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
domain.partner_id = proton_partner_id
|
||||
Session.commit()
|
||||
|
||||
wrong_record = random_string()
|
||||
dns_client.set_txt_record(
|
||||
hostname=domain.domain,
|
||||
txt_list=[wrong_record, validator.get_ownership_verification_record(domain)],
|
||||
)
|
||||
res = validator.validate_spf_records(domain)
|
||||
|
||||
assert res.success is False
|
||||
assert len(res.errors) == 1
|
||||
assert res.errors[0] == wrong_record
|
||||
|
||||
|
||||
# validate_dmarc_records
|
||||
def test_custom_domain_validation_validate_dmarc_records_empty_failure():
|
||||
dns_client = InMemoryDNSClient()
|
||||
|
|
Loading…
Reference in a new issue