diff --git a/app/config.py b/app/config.py index 9ec2cb65..7d6ff975 100644 --- a/app/config.py +++ b/app/config.py @@ -62,6 +62,17 @@ def get_env_dict(env_var: str) -> dict[str, str]: return result +def get_env_csv(env_var: str, default: Optional[str]) -> list[str]: + """ + Get an env variable and convert it into a list of strings separated by, + Syntax is: val1,val2 + """ + value = os.getenv(env_var, default) + if not value: + return [] + return [field.strip() for field in value.split(",") if field.strip()] + + config_file = os.environ.get("CONFIG") if config_file: config_file = get_abs_path(config_file) @@ -171,6 +182,14 @@ FIRST_ALIAS_DOMAIN = os.environ.get("FIRST_ALIAS_DOMAIN") or EMAIL_DOMAIN # e.g. [(10, "mx1.hostname."), (10, "mx2.hostname.")] EMAIL_SERVERS_WITH_PRIORITY = sl_getenv("EMAIL_SERVERS_WITH_PRIORITY") +PROTON_MX_SERVERS = get_env_csv( + "PROTON_MX_SERVERS", "mail.protonmail.ch., mailsec.protonmail.ch." +) + +PROTON_EMAIL_DOMAINS = get_env_csv( + "PROTON_EMAIL_DOMAINS", "proton.me, protonmail.com, protonmail.ch, proton.ch, pm.me" +) + # disable the alias suffix, i.e. the ".random_word" part DISABLE_ALIAS_SUFFIX = "DISABLE_ALIAS_SUFFIX" in os.environ diff --git a/app/dns_utils.py b/app/dns_utils.py index 02ca7841..d2fbd853 100644 --- a/app/dns_utils.py +++ b/app/dns_utils.py @@ -115,9 +115,20 @@ class InMemoryDNSClient(DNSClient): return self.txt_records.get(hostname, []) -def get_network_dns_client() -> NetworkDNSClient: +global_dns_client: Optional[DNSClient] = None + + +def get_network_dns_client() -> DNSClient: + global global_dns_client + if global_dns_client is not None: + return global_dns_client return NetworkDNSClient(NAMESERVERS) +def set_global_dns_client(dns_client: Optional[DNSClient]): + global global_dns_client + global_dns_client = dns_client + + def get_mx_domains(hostname: str) -> dict[int, list[str]]: return get_network_dns_client().get_mx_domains(hostname) diff --git a/app/models.py b/app/models.py index 2f572936..3032ba55 100644 --- a/app/models.py +++ b/app/models.py @@ -2838,24 +2838,20 @@ class Mailbox(Base, ModelMixin): return len(alias_ids) def is_proton(self) -> bool: - if ( - self.email.endswith("@proton.me") - or self.email.endswith("@protonmail.com") - or self.email.endswith("@protonmail.ch") - or self.email.endswith("@proton.ch") - or self.email.endswith("@pm.me") - ): - return True + for proton_email_domain in config.PROTON_EMAIL_DOMAINS: + if self.email.endswith(f"@{proton_email_domain}"): + return True from app.email_utils import get_email_local_part mx_domains = get_mx_domains(get_email_local_part(self.email)) + + proton_mx_domains = config.PROTON_MX_SERVERS # Proton is the first domain - if mx_domains and mx_domains[0].domain in ( - "mail.protonmail.ch.", - "mailsec.protonmail.ch.", - ): - return True + for prio in mx_domains: + for mx_domain in mx_domains[prio]: + if mx_domain in proton_mx_domains: + return True return False diff --git a/tests/models/test_mailbox.py b/tests/models/test_mailbox.py new file mode 100644 index 00000000..68f5d1fe --- /dev/null +++ b/tests/models/test_mailbox.py @@ -0,0 +1,37 @@ +from app import config +from app.dns_utils import set_global_dns_client, InMemoryDNSClient +from app.email_utils import get_email_local_part +from app.models import Mailbox +from tests.utils import create_new_user, random_email + +dns_client = InMemoryDNSClient() + + +def setup_module(): + set_global_dns_client(dns_client) + + +def teardown_module(): + set_global_dns_client(None) + + +def test_is_proton_with_email_domain(): + user = create_new_user() + mailbox = Mailbox.create( + user_id=user.id, email=f"test@{config.PROTON_EMAIL_DOMAINS[0]}" + ) + assert mailbox.is_proton() + mailbox = Mailbox.create(user_id=user.id, email="a@b.c") + assert not mailbox.is_proton() + + +def test_is_proton_with_mx_domain(): + email = random_email() + dns_client.set_mx_records( + get_email_local_part(email), {10: config.PROTON_MX_SERVERS} + ) + user = create_new_user() + mailbox = Mailbox.create(user_id=user.id, email=email) + assert mailbox.is_proton() + dns_client.set_mx_records(get_email_local_part(email), {10: ["nowhere.net"]}) + assert not mailbox.is_proton()