mirror of
https://github.com/simple-login/app.git
synced 2025-02-24 07:43:54 +08:00
Refactor alias options and add it to more methods (#1681)
Co-authored-by: Adrià Casajús <adria.casajus@proton.ch>
This commit is contained in:
parent
43b91cd197
commit
b6f79ea3a6
4 changed files with 72 additions and 36 deletions
|
@ -6,7 +6,7 @@ from typing import Optional
|
|||
import itsdangerous
|
||||
from app import config
|
||||
from app.log import LOG
|
||||
from app.models import User, Partner
|
||||
from app.models import User, AliasOptions
|
||||
|
||||
signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET)
|
||||
|
||||
|
@ -42,7 +42,9 @@ def check_suffix_signature(signed_suffix: str) -> Optional[str]:
|
|||
return None
|
||||
|
||||
|
||||
def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
|
||||
def verify_prefix_suffix(
|
||||
user: User, alias_prefix, alias_suffix, alias_options: Optional[AliasOptions] = None
|
||||
) -> bool:
|
||||
"""verify if user could create an alias with the given prefix and suffix"""
|
||||
if not alias_prefix or not alias_suffix: # should be caught on frontend
|
||||
return False
|
||||
|
@ -63,7 +65,7 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
|
|||
# 1) alias_suffix must start with "." and
|
||||
# 2) alias_domain_prefix must come from the word list
|
||||
if (
|
||||
alias_domain in user.available_sl_domains()
|
||||
alias_domain in user.available_sl_domains(alias_options=alias_options)
|
||||
and alias_domain not in user_custom_domains
|
||||
# when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty
|
||||
and not config.DISABLE_ALIAS_SUFFIX
|
||||
|
@ -79,7 +81,9 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
|
|||
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
|
||||
return False
|
||||
|
||||
if alias_domain not in user.available_sl_domains():
|
||||
if alias_domain not in user.available_sl_domains(
|
||||
alias_options=alias_options
|
||||
):
|
||||
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
|
||||
return False
|
||||
|
||||
|
@ -87,9 +91,7 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
|
|||
|
||||
|
||||
def get_alias_suffixes(
|
||||
user: User,
|
||||
show_domains_for_partner: Optional[Partner] = None,
|
||||
show_sl_domains: bool = True,
|
||||
user: User, alias_options: Optional[AliasOptions] = None
|
||||
) -> [AliasSuffix]:
|
||||
"""
|
||||
Similar to as get_available_suffixes() but also return custom domain that doesn't have MX set up.
|
||||
|
@ -142,7 +144,7 @@ def get_alias_suffixes(
|
|||
alias_suffixes.append(alias_suffix)
|
||||
|
||||
# then SimpleLogin domain
|
||||
for sl_domain in user.get_sl_domains(show_domains_for_partner, show_sl_domains):
|
||||
for sl_domain in user.get_sl_domains(alias_options=alias_options):
|
||||
suffix = (
|
||||
(
|
||||
""
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import dataclasses
|
||||
import enum
|
||||
import hashlib
|
||||
import hmac
|
||||
|
@ -273,6 +274,12 @@ class IntEnumType(sa.types.TypeDecorator):
|
|||
return self._enum_type(enum_value)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AliasOptions:
|
||||
show_sl_domains: bool = True
|
||||
show_partner_domains: Optional[Partner] = None
|
||||
|
||||
|
||||
class Hibp(Base, ModelMixin):
|
||||
__tablename__ = "hibp"
|
||||
name = sa.Column(sa.String(), nullable=False, unique=True, index=True)
|
||||
|
@ -867,14 +874,16 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
|
|||
def custom_domains(self):
|
||||
return CustomDomain.filter_by(user_id=self.id, verified=True).all()
|
||||
|
||||
def available_domains_for_random_alias(self) -> List[Tuple[bool, str]]:
|
||||
def available_domains_for_random_alias(
|
||||
self, alias_options: Optional[AliasOptions] = None
|
||||
) -> List[Tuple[bool, str]]:
|
||||
"""Return available domains for user to create random aliases
|
||||
Each result record contains:
|
||||
- whether the domain belongs to SimpleLogin
|
||||
- the domain
|
||||
"""
|
||||
res = []
|
||||
for domain in self.available_sl_domains():
|
||||
for domain in self.available_sl_domains(alias_options=alias_options):
|
||||
res.append((True, domain))
|
||||
|
||||
for custom_domain in self.verified_custom_domains():
|
||||
|
@ -959,32 +968,37 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
|
|||
|
||||
return None, "", False
|
||||
|
||||
def available_sl_domains(self) -> [str]:
|
||||
def available_sl_domains(
|
||||
self, alias_options: Optional[AliasOptions] = None
|
||||
) -> [str]:
|
||||
"""
|
||||
Return all SimpleLogin domains that user can use when creating a new alias, including:
|
||||
- SimpleLogin public domains, available for all users (ALIAS_DOMAIN)
|
||||
- SimpleLogin premium domains, only available for Premium accounts (PREMIUM_ALIAS_DOMAIN)
|
||||
"""
|
||||
return [sl_domain.domain for sl_domain in self.get_sl_domains()]
|
||||
return [
|
||||
sl_domain.domain
|
||||
for sl_domain in self.get_sl_domains(alias_options=alias_options)
|
||||
]
|
||||
|
||||
def get_sl_domains(
|
||||
self,
|
||||
show_domains_for_partner: Optional[Partner] = None,
|
||||
show_sl_domains: bool = True,
|
||||
self, alias_options: Optional[AliasOptions] = None
|
||||
) -> list["SLDomain"]:
|
||||
if alias_options is None:
|
||||
alias_options = AliasOptions()
|
||||
conditions = [SLDomain.hidden == False] # noqa: E712
|
||||
if not self.is_premium():
|
||||
conditions.append(SLDomain.premium_only == False) # noqa: E712
|
||||
partner_domain_cond = [] # noqa:E711
|
||||
if show_domains_for_partner is not None:
|
||||
if alias_options.show_partner_domains is not None:
|
||||
partner_user = PartnerUser.filter_by(
|
||||
user_id=self.id, partner_id=show_domains_for_partner.id
|
||||
user_id=self.id, partner_id=alias_options.show_partner_domains.id
|
||||
).first()
|
||||
if partner_user is not None:
|
||||
partner_domain_cond.append(
|
||||
SLDomain.partner_id == partner_user.partner_id
|
||||
)
|
||||
if show_sl_domains:
|
||||
if alias_options.show_sl_domains:
|
||||
partner_domain_cond.append(SLDomain.partner_id == None) # noqa:E711
|
||||
if len(partner_domain_cond) == 1:
|
||||
conditions.append(partner_domain_cond[0])
|
||||
|
@ -993,14 +1007,16 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
|
|||
query = Session.query(SLDomain).filter(*conditions).order_by(SLDomain.order)
|
||||
return query.all()
|
||||
|
||||
def available_alias_domains(self) -> [str]:
|
||||
def available_alias_domains(
|
||||
self, alias_options: Optional[AliasOptions] = None
|
||||
) -> [str]:
|
||||
"""return all domains that user can use when creating a new alias, including:
|
||||
- SimpleLogin public domains, available for all users (ALIAS_DOMAIN)
|
||||
- SimpleLogin premium domains, only available for Premium accounts (PREMIUM_ALIAS_DOMAIN)
|
||||
- Verified custom domains
|
||||
|
||||
"""
|
||||
domains = self.available_sl_domains()
|
||||
domains = self.available_sl_domains(alias_options=alias_options)
|
||||
|
||||
for custom_domain in self.verified_custom_domains():
|
||||
domains.append(custom_domain.domain)
|
||||
|
|
|
@ -1,13 +1,7 @@
|
|||
from app import config
|
||||
from app.db import Session
|
||||
from app.models import User, Job
|
||||
from tests.utils import create_new_user, random_email
|
||||
|
||||
|
||||
def test_available_sl_domains(flask_client):
|
||||
user = create_new_user()
|
||||
|
||||
assert set(user.available_sl_domains()) == {"d1.test", "d2.test", "sl.local"}
|
||||
from tests.utils import random_email
|
||||
|
||||
|
||||
def test_create_from_partner(flask_client):
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from app.db import Session
|
||||
from app.models import SLDomain, PartnerUser
|
||||
from app.models import SLDomain, PartnerUser, AliasOptions
|
||||
from app.proton.utils import get_proton_partner
|
||||
from init_app import add_sl_domains
|
||||
from tests.utils import create_new_user, random_token
|
||||
|
@ -43,12 +43,14 @@ def test_get_non_partner_domains():
|
|||
assert len(domains) == 2
|
||||
assert domains[0].domain == "premium_non_partner"
|
||||
assert domains[1].domain == "free_non_partner"
|
||||
assert [d.domain for d in domains] == user.available_sl_domains()
|
||||
# Free
|
||||
user.trial_end = None
|
||||
Session.flush()
|
||||
domains = user.get_sl_domains()
|
||||
assert len(domains) == 1
|
||||
assert domains[0].domain == "free_non_partner"
|
||||
assert [d.domain for d in domains] == user.available_sl_domains()
|
||||
|
||||
|
||||
def test_get_free_with_partner_domains():
|
||||
|
@ -64,17 +66,28 @@ def test_get_free_with_partner_domains():
|
|||
# Default
|
||||
assert len(domains) == 1
|
||||
assert domains[0].domain == "free_non_partner"
|
||||
assert [d.domain for d in domains] == user.available_sl_domains()
|
||||
# Show partner domains
|
||||
domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner())
|
||||
options = AliasOptions(
|
||||
show_sl_domains=True, show_partner_domains=get_proton_partner()
|
||||
)
|
||||
domains = user.get_sl_domains(alias_options=options)
|
||||
assert len(domains) == 2
|
||||
assert domains[0].domain == "free_partner"
|
||||
assert domains[1].domain == "free_non_partner"
|
||||
# Only partner domains
|
||||
domains = user.get_sl_domains(
|
||||
show_domains_for_partner=get_proton_partner(), show_sl_domains=False
|
||||
assert [d.domain for d in domains] == user.available_sl_domains(
|
||||
alias_options=options
|
||||
)
|
||||
# Only partner domains
|
||||
options = AliasOptions(
|
||||
show_sl_domains=False, show_partner_domains=get_proton_partner()
|
||||
)
|
||||
domains = user.get_sl_domains(alias_options=options)
|
||||
assert len(domains) == 1
|
||||
assert domains[0].domain == "free_partner"
|
||||
assert [d.domain for d in domains] == user.available_sl_domains(
|
||||
alias_options=options
|
||||
)
|
||||
|
||||
|
||||
def test_get_premium_with_partner_domains():
|
||||
|
@ -90,17 +103,28 @@ def test_get_premium_with_partner_domains():
|
|||
assert len(domains) == 2
|
||||
assert domains[0].domain == "premium_non_partner"
|
||||
assert domains[1].domain == "free_non_partner"
|
||||
assert [d.domain for d in domains] == user.available_sl_domains()
|
||||
# Show partner domains
|
||||
domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner())
|
||||
options = AliasOptions(
|
||||
show_sl_domains=True, show_partner_domains=get_proton_partner()
|
||||
)
|
||||
domains = user.get_sl_domains(alias_options=options)
|
||||
assert len(domains) == 4
|
||||
assert domains[0].domain == "premium_partner"
|
||||
assert domains[1].domain == "free_partner"
|
||||
assert domains[2].domain == "premium_non_partner"
|
||||
assert domains[3].domain == "free_non_partner"
|
||||
# Only partner domains
|
||||
domains = user.get_sl_domains(
|
||||
show_domains_for_partner=get_proton_partner(), show_sl_domains=False
|
||||
assert [d.domain for d in domains] == user.available_sl_domains(
|
||||
alias_options=options
|
||||
)
|
||||
# Only partner domains
|
||||
options = AliasOptions(
|
||||
show_sl_domains=False, show_partner_domains=get_proton_partner()
|
||||
)
|
||||
domains = user.get_sl_domains(alias_options=options)
|
||||
assert len(domains) == 2
|
||||
assert domains[0].domain == "premium_partner"
|
||||
assert domains[1].domain == "free_partner"
|
||||
assert [d.domain for d in domains] == user.available_sl_domains(
|
||||
alias_options=options
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue