Refactor alias options and add it to more methods (#1681)

Co-authored-by: Adrià Casajús <adria.casajus@proton.ch>
This commit is contained in:
Adrià Casajús 2023-04-06 11:07:13 +02:00 committed by GitHub
parent 43b91cd197
commit b6f79ea3a6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 72 additions and 36 deletions

View file

@ -6,7 +6,7 @@ from typing import Optional
import itsdangerous import itsdangerous
from app import config from app import config
from app.log import LOG from app.log import LOG
from app.models import User, Partner from app.models import User, AliasOptions
signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET) signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET)
@ -42,7 +42,9 @@ def check_suffix_signature(signed_suffix: str) -> Optional[str]:
return None return None
def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool: def verify_prefix_suffix(
user: User, alias_prefix, alias_suffix, alias_options: Optional[AliasOptions] = None
) -> bool:
"""verify if user could create an alias with the given prefix and suffix""" """verify if user could create an alias with the given prefix and suffix"""
if not alias_prefix or not alias_suffix: # should be caught on frontend if not alias_prefix or not alias_suffix: # should be caught on frontend
return False return False
@ -63,7 +65,7 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
# 1) alias_suffix must start with "." and # 1) alias_suffix must start with "." and
# 2) alias_domain_prefix must come from the word list # 2) alias_domain_prefix must come from the word list
if ( if (
alias_domain in user.available_sl_domains() alias_domain in user.available_sl_domains(alias_options=alias_options)
and alias_domain not in user_custom_domains and alias_domain not in user_custom_domains
# when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty # when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty
and not config.DISABLE_ALIAS_SUFFIX and not config.DISABLE_ALIAS_SUFFIX
@ -79,7 +81,9 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user) LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False return False
if alias_domain not in user.available_sl_domains(): if alias_domain not in user.available_sl_domains(
alias_options=alias_options
):
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user) LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False return False
@ -87,9 +91,7 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
def get_alias_suffixes( def get_alias_suffixes(
user: User, user: User, alias_options: Optional[AliasOptions] = None
show_domains_for_partner: Optional[Partner] = None,
show_sl_domains: bool = True,
) -> [AliasSuffix]: ) -> [AliasSuffix]:
""" """
Similar to as get_available_suffixes() but also return custom domain that doesn't have MX set up. Similar to as get_available_suffixes() but also return custom domain that doesn't have MX set up.
@ -142,7 +144,7 @@ def get_alias_suffixes(
alias_suffixes.append(alias_suffix) alias_suffixes.append(alias_suffix)
# then SimpleLogin domain # then SimpleLogin domain
for sl_domain in user.get_sl_domains(show_domains_for_partner, show_sl_domains): for sl_domain in user.get_sl_domains(alias_options=alias_options):
suffix = ( suffix = (
( (
"" ""

View file

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import base64 import base64
import dataclasses
import enum import enum
import hashlib import hashlib
import hmac import hmac
@ -273,6 +274,12 @@ class IntEnumType(sa.types.TypeDecorator):
return self._enum_type(enum_value) return self._enum_type(enum_value)
@dataclasses.dataclass
class AliasOptions:
show_sl_domains: bool = True
show_partner_domains: Optional[Partner] = None
class Hibp(Base, ModelMixin): class Hibp(Base, ModelMixin):
__tablename__ = "hibp" __tablename__ = "hibp"
name = sa.Column(sa.String(), nullable=False, unique=True, index=True) name = sa.Column(sa.String(), nullable=False, unique=True, index=True)
@ -867,14 +874,16 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
def custom_domains(self): def custom_domains(self):
return CustomDomain.filter_by(user_id=self.id, verified=True).all() return CustomDomain.filter_by(user_id=self.id, verified=True).all()
def available_domains_for_random_alias(self) -> List[Tuple[bool, str]]: def available_domains_for_random_alias(
self, alias_options: Optional[AliasOptions] = None
) -> List[Tuple[bool, str]]:
"""Return available domains for user to create random aliases """Return available domains for user to create random aliases
Each result record contains: Each result record contains:
- whether the domain belongs to SimpleLogin - whether the domain belongs to SimpleLogin
- the domain - the domain
""" """
res = [] res = []
for domain in self.available_sl_domains(): for domain in self.available_sl_domains(alias_options=alias_options):
res.append((True, domain)) res.append((True, domain))
for custom_domain in self.verified_custom_domains(): for custom_domain in self.verified_custom_domains():
@ -959,32 +968,37 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
return None, "", False return None, "", False
def available_sl_domains(self) -> [str]: def available_sl_domains(
self, alias_options: Optional[AliasOptions] = None
) -> [str]:
""" """
Return all SimpleLogin domains that user can use when creating a new alias, including: Return all SimpleLogin domains that user can use when creating a new alias, including:
- SimpleLogin public domains, available for all users (ALIAS_DOMAIN) - SimpleLogin public domains, available for all users (ALIAS_DOMAIN)
- SimpleLogin premium domains, only available for Premium accounts (PREMIUM_ALIAS_DOMAIN) - SimpleLogin premium domains, only available for Premium accounts (PREMIUM_ALIAS_DOMAIN)
""" """
return [sl_domain.domain for sl_domain in self.get_sl_domains()] return [
sl_domain.domain
for sl_domain in self.get_sl_domains(alias_options=alias_options)
]
def get_sl_domains( def get_sl_domains(
self, self, alias_options: Optional[AliasOptions] = None
show_domains_for_partner: Optional[Partner] = None,
show_sl_domains: bool = True,
) -> list["SLDomain"]: ) -> list["SLDomain"]:
if alias_options is None:
alias_options = AliasOptions()
conditions = [SLDomain.hidden == False] # noqa: E712 conditions = [SLDomain.hidden == False] # noqa: E712
if not self.is_premium(): if not self.is_premium():
conditions.append(SLDomain.premium_only == False) # noqa: E712 conditions.append(SLDomain.premium_only == False) # noqa: E712
partner_domain_cond = [] # noqa:E711 partner_domain_cond = [] # noqa:E711
if show_domains_for_partner is not None: if alias_options.show_partner_domains is not None:
partner_user = PartnerUser.filter_by( partner_user = PartnerUser.filter_by(
user_id=self.id, partner_id=show_domains_for_partner.id user_id=self.id, partner_id=alias_options.show_partner_domains.id
).first() ).first()
if partner_user is not None: if partner_user is not None:
partner_domain_cond.append( partner_domain_cond.append(
SLDomain.partner_id == partner_user.partner_id SLDomain.partner_id == partner_user.partner_id
) )
if show_sl_domains: if alias_options.show_sl_domains:
partner_domain_cond.append(SLDomain.partner_id == None) # noqa:E711 partner_domain_cond.append(SLDomain.partner_id == None) # noqa:E711
if len(partner_domain_cond) == 1: if len(partner_domain_cond) == 1:
conditions.append(partner_domain_cond[0]) conditions.append(partner_domain_cond[0])
@ -993,14 +1007,16 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
query = Session.query(SLDomain).filter(*conditions).order_by(SLDomain.order) query = Session.query(SLDomain).filter(*conditions).order_by(SLDomain.order)
return query.all() return query.all()
def available_alias_domains(self) -> [str]: def available_alias_domains(
self, alias_options: Optional[AliasOptions] = None
) -> [str]:
"""return all domains that user can use when creating a new alias, including: """return all domains that user can use when creating a new alias, including:
- SimpleLogin public domains, available for all users (ALIAS_DOMAIN) - SimpleLogin public domains, available for all users (ALIAS_DOMAIN)
- SimpleLogin premium domains, only available for Premium accounts (PREMIUM_ALIAS_DOMAIN) - SimpleLogin premium domains, only available for Premium accounts (PREMIUM_ALIAS_DOMAIN)
- Verified custom domains - Verified custom domains
""" """
domains = self.available_sl_domains() domains = self.available_sl_domains(alias_options=alias_options)
for custom_domain in self.verified_custom_domains(): for custom_domain in self.verified_custom_domains():
domains.append(custom_domain.domain) domains.append(custom_domain.domain)

View file

@ -1,13 +1,7 @@
from app import config from app import config
from app.db import Session from app.db import Session
from app.models import User, Job from app.models import User, Job
from tests.utils import create_new_user, random_email from tests.utils import random_email
def test_available_sl_domains(flask_client):
user = create_new_user()
assert set(user.available_sl_domains()) == {"d1.test", "d2.test", "sl.local"}
def test_create_from_partner(flask_client): def test_create_from_partner(flask_client):

View file

@ -1,5 +1,5 @@
from app.db import Session from app.db import Session
from app.models import SLDomain, PartnerUser from app.models import SLDomain, PartnerUser, AliasOptions
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from init_app import add_sl_domains from init_app import add_sl_domains
from tests.utils import create_new_user, random_token from tests.utils import create_new_user, random_token
@ -43,12 +43,14 @@ def test_get_non_partner_domains():
assert len(domains) == 2 assert len(domains) == 2
assert domains[0].domain == "premium_non_partner" assert domains[0].domain == "premium_non_partner"
assert domains[1].domain == "free_non_partner" assert domains[1].domain == "free_non_partner"
assert [d.domain for d in domains] == user.available_sl_domains()
# Free # Free
user.trial_end = None user.trial_end = None
Session.flush() Session.flush()
domains = user.get_sl_domains() domains = user.get_sl_domains()
assert len(domains) == 1 assert len(domains) == 1
assert domains[0].domain == "free_non_partner" assert domains[0].domain == "free_non_partner"
assert [d.domain for d in domains] == user.available_sl_domains()
def test_get_free_with_partner_domains(): def test_get_free_with_partner_domains():
@ -64,17 +66,28 @@ def test_get_free_with_partner_domains():
# Default # Default
assert len(domains) == 1 assert len(domains) == 1
assert domains[0].domain == "free_non_partner" assert domains[0].domain == "free_non_partner"
assert [d.domain for d in domains] == user.available_sl_domains()
# Show partner domains # Show partner domains
domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner()) options = AliasOptions(
show_sl_domains=True, show_partner_domains=get_proton_partner()
)
domains = user.get_sl_domains(alias_options=options)
assert len(domains) == 2 assert len(domains) == 2
assert domains[0].domain == "free_partner" assert domains[0].domain == "free_partner"
assert domains[1].domain == "free_non_partner" assert domains[1].domain == "free_non_partner"
# Only partner domains assert [d.domain for d in domains] == user.available_sl_domains(
domains = user.get_sl_domains( alias_options=options
show_domains_for_partner=get_proton_partner(), show_sl_domains=False
) )
# Only partner domains
options = AliasOptions(
show_sl_domains=False, show_partner_domains=get_proton_partner()
)
domains = user.get_sl_domains(alias_options=options)
assert len(domains) == 1 assert len(domains) == 1
assert domains[0].domain == "free_partner" assert domains[0].domain == "free_partner"
assert [d.domain for d in domains] == user.available_sl_domains(
alias_options=options
)
def test_get_premium_with_partner_domains(): def test_get_premium_with_partner_domains():
@ -90,17 +103,28 @@ def test_get_premium_with_partner_domains():
assert len(domains) == 2 assert len(domains) == 2
assert domains[0].domain == "premium_non_partner" assert domains[0].domain == "premium_non_partner"
assert domains[1].domain == "free_non_partner" assert domains[1].domain == "free_non_partner"
assert [d.domain for d in domains] == user.available_sl_domains()
# Show partner domains # Show partner domains
domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner()) options = AliasOptions(
show_sl_domains=True, show_partner_domains=get_proton_partner()
)
domains = user.get_sl_domains(alias_options=options)
assert len(domains) == 4 assert len(domains) == 4
assert domains[0].domain == "premium_partner" assert domains[0].domain == "premium_partner"
assert domains[1].domain == "free_partner" assert domains[1].domain == "free_partner"
assert domains[2].domain == "premium_non_partner" assert domains[2].domain == "premium_non_partner"
assert domains[3].domain == "free_non_partner" assert domains[3].domain == "free_non_partner"
# Only partner domains assert [d.domain for d in domains] == user.available_sl_domains(
domains = user.get_sl_domains( alias_options=options
show_domains_for_partner=get_proton_partner(), show_sl_domains=False
) )
# Only partner domains
options = AliasOptions(
show_sl_domains=False, show_partner_domains=get_proton_partner()
)
domains = user.get_sl_domains(alias_options=options)
assert len(domains) == 2 assert len(domains) == 2
assert domains[0].domain == "premium_partner" assert domains[0].domain == "premium_partner"
assert domains[1].domain == "free_partner" assert domains[1].domain == "free_partner"
assert [d.domain for d in domains] == user.available_sl_domains(
alias_options=options
)