mirror of
https://github.com/simple-login/app.git
synced 2024-11-10 17:35:27 +08:00
Create Partner only domains (#1665)
* Add Partner only domains * Add hidden domain to the test and revert to default domains after the tests * Send what to show in each call * Fix: Pass none instead of false * Removed flag from partnerusr --------- Co-authored-by: Adrià Casajús <adria.casajus@proton.ch>
This commit is contained in:
parent
03e5083d97
commit
43b91cd197
4 changed files with 201 additions and 36 deletions
|
@ -6,8 +6,7 @@ from typing import Optional
|
|||
import itsdangerous
|
||||
from app import config
|
||||
from app.log import LOG
|
||||
from app.models import User
|
||||
|
||||
from app.models import User, Partner
|
||||
|
||||
signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET)
|
||||
|
||||
|
@ -87,7 +86,11 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def get_alias_suffixes(user: User) -> [AliasSuffix]:
|
||||
def get_alias_suffixes(
|
||||
user: User,
|
||||
show_domains_for_partner: Optional[Partner] = None,
|
||||
show_sl_domains: bool = True,
|
||||
) -> [AliasSuffix]:
|
||||
"""
|
||||
Similar to as get_available_suffixes() but also return custom domain that doesn't have MX set up.
|
||||
"""
|
||||
|
@ -139,7 +142,7 @@ def get_alias_suffixes(user: User) -> [AliasSuffix]:
|
|||
alias_suffixes.append(alias_suffix)
|
||||
|
||||
# then SimpleLogin domain
|
||||
for sl_domain in user.get_sl_domains():
|
||||
for sl_domain in user.get_sl_domains(show_domains_for_partner, show_sl_domains):
|
||||
suffix = (
|
||||
(
|
||||
""
|
||||
|
|
|
@ -18,7 +18,7 @@ from flanker.addresslib import address
|
|||
from flask import url_for
|
||||
from flask_login import UserMixin
|
||||
from jinja2 import FileSystemLoader, Environment
|
||||
from sqlalchemy import orm
|
||||
from sqlalchemy import orm, or_
|
||||
from sqlalchemy import text, desc, CheckConstraint, Index, Column
|
||||
from sqlalchemy.dialects.postgresql import TSVECTOR
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
@ -967,13 +967,31 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
|
|||
"""
|
||||
return [sl_domain.domain for sl_domain in self.get_sl_domains()]
|
||||
|
||||
def get_sl_domains(self) -> List["SLDomain"]:
|
||||
query = SLDomain.filter_by(hidden=False).order_by(SLDomain.order)
|
||||
|
||||
if self.is_premium():
|
||||
return query.all()
|
||||
def get_sl_domains(
|
||||
self,
|
||||
show_domains_for_partner: Optional[Partner] = None,
|
||||
show_sl_domains: bool = True,
|
||||
) -> list["SLDomain"]:
|
||||
conditions = [SLDomain.hidden == False] # noqa: E712
|
||||
if not self.is_premium():
|
||||
conditions.append(SLDomain.premium_only == False) # noqa: E712
|
||||
partner_domain_cond = [] # noqa:E711
|
||||
if show_domains_for_partner is not None:
|
||||
partner_user = PartnerUser.filter_by(
|
||||
user_id=self.id, partner_id=show_domains_for_partner.id
|
||||
).first()
|
||||
if partner_user is not None:
|
||||
partner_domain_cond.append(
|
||||
SLDomain.partner_id == partner_user.partner_id
|
||||
)
|
||||
if show_sl_domains:
|
||||
partner_domain_cond.append(SLDomain.partner_id == None) # noqa:E711
|
||||
if len(partner_domain_cond) == 1:
|
||||
conditions.append(partner_domain_cond[0])
|
||||
else:
|
||||
return query.filter_by(premium_only=False).all()
|
||||
conditions.append(or_(*partner_domain_cond))
|
||||
query = Session.query(SLDomain).filter(*conditions).order_by(SLDomain.order)
|
||||
return query.all()
|
||||
|
||||
def available_alias_domains(self) -> [str]:
|
||||
"""return all domains that user can use when creating a new alias, including:
|
||||
|
@ -2768,6 +2786,31 @@ class Notification(Base, ModelMixin):
|
|||
)
|
||||
|
||||
|
||||
class Partner(Base, ModelMixin):
|
||||
__tablename__ = "partner"
|
||||
|
||||
name = sa.Column(sa.String(128), unique=True, nullable=False)
|
||||
contact_email = sa.Column(sa.String(128), unique=True, nullable=False)
|
||||
|
||||
@staticmethod
|
||||
def find_by_token(token: str) -> Optional[Partner]:
|
||||
hmaced = PartnerApiToken.hmac_token(token)
|
||||
res = (
|
||||
Session.query(Partner, PartnerApiToken)
|
||||
.filter(
|
||||
and_(
|
||||
PartnerApiToken.token == hmaced,
|
||||
Partner.id == PartnerApiToken.partner_id,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if res:
|
||||
partner, partner_api_token = res
|
||||
return partner
|
||||
return None
|
||||
|
||||
|
||||
class SLDomain(Base, ModelMixin):
|
||||
"""SimpleLogin domains"""
|
||||
|
||||
|
@ -2785,6 +2828,13 @@ class SLDomain(Base, ModelMixin):
|
|||
sa.Boolean, nullable=False, default=False, server_default="0"
|
||||
)
|
||||
|
||||
partner_id = sa.Column(
|
||||
sa.ForeignKey(Partner.id, ondelete="cascade"),
|
||||
nullable=True,
|
||||
default=None,
|
||||
sever_default="NULL",
|
||||
)
|
||||
|
||||
# if enabled, do not show this domain when user creates a custom alias
|
||||
hidden = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0")
|
||||
|
||||
|
@ -3231,31 +3281,6 @@ class ProviderComplaint(Base, ModelMixin):
|
|||
refused_email = orm.relationship(RefusedEmail, foreign_keys=[refused_email_id])
|
||||
|
||||
|
||||
class Partner(Base, ModelMixin):
|
||||
__tablename__ = "partner"
|
||||
|
||||
name = sa.Column(sa.String(128), unique=True, nullable=False)
|
||||
contact_email = sa.Column(sa.String(128), unique=True, nullable=False)
|
||||
|
||||
@staticmethod
|
||||
def find_by_token(token: str) -> Optional[Partner]:
|
||||
hmaced = PartnerApiToken.hmac_token(token)
|
||||
res = (
|
||||
Session.query(Partner, PartnerApiToken)
|
||||
.filter(
|
||||
and_(
|
||||
PartnerApiToken.token == hmaced,
|
||||
Partner.id == PartnerApiToken.partner_id,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if res:
|
||||
partner, partner_api_token = res
|
||||
return partner
|
||||
return None
|
||||
|
||||
|
||||
class PartnerApiToken(Base, ModelMixin):
|
||||
__tablename__ = "partner_api_token"
|
||||
|
||||
|
|
31
migrations/versions/2023_040318_5f4a5625da66_.py
Normal file
31
migrations/versions/2023_040318_5f4a5625da66_.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
"""empty message
|
||||
|
||||
Revision ID: 5f4a5625da66
|
||||
Revises: 2c2093c82bc0
|
||||
Create Date: 2023-04-03 18:30:46.488231
|
||||
|
||||
"""
|
||||
import sqlalchemy_utils
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '5f4a5625da66'
|
||||
down_revision = '2c2093c82bc0'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('public_domain', sa.Column('partner_id', sa.Integer(), nullable=True, sever_default='NULL'))
|
||||
op.create_foreign_key(None, 'public_domain', 'partner', ['partner_id'], ['id'], ondelete='cascade')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint(None, 'public_domain', type_='foreignkey')
|
||||
op.drop_column('public_domain', 'partner_id')
|
||||
# ### end Alembic commands ###
|
106
tests/test_domains.py
Normal file
106
tests/test_domains.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
from app.db import Session
|
||||
from app.models import SLDomain, PartnerUser
|
||||
from app.proton.utils import get_proton_partner
|
||||
from init_app import add_sl_domains
|
||||
from tests.utils import create_new_user, random_token
|
||||
|
||||
|
||||
def setup_module():
|
||||
Session.query(SLDomain).delete()
|
||||
SLDomain.create(
|
||||
domain="hidden", premium_only=False, flush=True, order=5, hidden=True
|
||||
)
|
||||
SLDomain.create(domain="free_non_partner", premium_only=False, flush=True, order=4)
|
||||
SLDomain.create(
|
||||
domain="premium_non_partner", premium_only=True, flush=True, order=3
|
||||
)
|
||||
SLDomain.create(
|
||||
domain="free_partner",
|
||||
premium_only=False,
|
||||
flush=True,
|
||||
partner_id=get_proton_partner().id,
|
||||
order=2,
|
||||
)
|
||||
SLDomain.create(
|
||||
domain="premium_partner",
|
||||
premium_only=True,
|
||||
flush=True,
|
||||
partner_id=get_proton_partner().id,
|
||||
order=1,
|
||||
)
|
||||
Session.commit()
|
||||
|
||||
|
||||
def teardown_module():
|
||||
Session.query(SLDomain).delete()
|
||||
add_sl_domains()
|
||||
|
||||
|
||||
def test_get_non_partner_domains():
|
||||
user = create_new_user()
|
||||
domains = user.get_sl_domains()
|
||||
# Premium
|
||||
assert len(domains) == 2
|
||||
assert domains[0].domain == "premium_non_partner"
|
||||
assert domains[1].domain == "free_non_partner"
|
||||
# Free
|
||||
user.trial_end = None
|
||||
Session.flush()
|
||||
domains = user.get_sl_domains()
|
||||
assert len(domains) == 1
|
||||
assert domains[0].domain == "free_non_partner"
|
||||
|
||||
|
||||
def test_get_free_with_partner_domains():
|
||||
user = create_new_user()
|
||||
user.trial_end = None
|
||||
PartnerUser.create(
|
||||
partner_id=get_proton_partner().id,
|
||||
user_id=user.id,
|
||||
external_user_id=random_token(10),
|
||||
flush=True,
|
||||
)
|
||||
domains = user.get_sl_domains()
|
||||
# Default
|
||||
assert len(domains) == 1
|
||||
assert domains[0].domain == "free_non_partner"
|
||||
# Show partner domains
|
||||
domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner())
|
||||
assert len(domains) == 2
|
||||
assert domains[0].domain == "free_partner"
|
||||
assert domains[1].domain == "free_non_partner"
|
||||
# Only partner domains
|
||||
domains = user.get_sl_domains(
|
||||
show_domains_for_partner=get_proton_partner(), show_sl_domains=False
|
||||
)
|
||||
assert len(domains) == 1
|
||||
assert domains[0].domain == "free_partner"
|
||||
|
||||
|
||||
def test_get_premium_with_partner_domains():
|
||||
user = create_new_user()
|
||||
PartnerUser.create(
|
||||
partner_id=get_proton_partner().id,
|
||||
user_id=user.id,
|
||||
external_user_id=random_token(10),
|
||||
flush=True,
|
||||
)
|
||||
domains = user.get_sl_domains()
|
||||
# Default
|
||||
assert len(domains) == 2
|
||||
assert domains[0].domain == "premium_non_partner"
|
||||
assert domains[1].domain == "free_non_partner"
|
||||
# Show partner domains
|
||||
domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner())
|
||||
assert len(domains) == 4
|
||||
assert domains[0].domain == "premium_partner"
|
||||
assert domains[1].domain == "free_partner"
|
||||
assert domains[2].domain == "premium_non_partner"
|
||||
assert domains[3].domain == "free_non_partner"
|
||||
# Only partner domains
|
||||
domains = user.get_sl_domains(
|
||||
show_domains_for_partner=get_proton_partner(), show_sl_domains=False
|
||||
)
|
||||
assert len(domains) == 2
|
||||
assert domains[0].domain == "premium_partner"
|
||||
assert domains[1].domain == "free_partner"
|
Loading…
Reference in a new issue