Create Partner only domains (#1665)

* Add Partner only domains

* Add hidden domain to the test and revert to default domains after the tests

* Send what to show in each call

* Fix: Pass none instead of false

* Removed flag from partnerusr

---------

Co-authored-by: Adrià Casajús <adria.casajus@proton.ch>
This commit is contained in:
Adrià Casajús 2023-04-04 15:21:51 +02:00 committed by GitHub
parent 03e5083d97
commit 43b91cd197
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 201 additions and 36 deletions

View file

@ -6,8 +6,7 @@ from typing import Optional
import itsdangerous
from app import config
from app.log import LOG
from app.models import User
from app.models import User, Partner
signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET)
@ -87,7 +86,11 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
return True
def get_alias_suffixes(user: User) -> [AliasSuffix]:
def get_alias_suffixes(
user: User,
show_domains_for_partner: Optional[Partner] = None,
show_sl_domains: bool = True,
) -> [AliasSuffix]:
"""
Similar to as get_available_suffixes() but also return custom domain that doesn't have MX set up.
"""
@ -139,7 +142,7 @@ def get_alias_suffixes(user: User) -> [AliasSuffix]:
alias_suffixes.append(alias_suffix)
# then SimpleLogin domain
for sl_domain in user.get_sl_domains():
for sl_domain in user.get_sl_domains(show_domains_for_partner, show_sl_domains):
suffix = (
(
""

View file

@ -18,7 +18,7 @@ from flanker.addresslib import address
from flask import url_for
from flask_login import UserMixin
from jinja2 import FileSystemLoader, Environment
from sqlalchemy import orm
from sqlalchemy import orm, or_
from sqlalchemy import text, desc, CheckConstraint, Index, Column
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.ext.declarative import declarative_base
@ -967,13 +967,31 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
"""
return [sl_domain.domain for sl_domain in self.get_sl_domains()]
def get_sl_domains(self) -> List["SLDomain"]:
query = SLDomain.filter_by(hidden=False).order_by(SLDomain.order)
if self.is_premium():
return query.all()
def get_sl_domains(
self,
show_domains_for_partner: Optional[Partner] = None,
show_sl_domains: bool = True,
) -> list["SLDomain"]:
conditions = [SLDomain.hidden == False] # noqa: E712
if not self.is_premium():
conditions.append(SLDomain.premium_only == False) # noqa: E712
partner_domain_cond = [] # noqa:E711
if show_domains_for_partner is not None:
partner_user = PartnerUser.filter_by(
user_id=self.id, partner_id=show_domains_for_partner.id
).first()
if partner_user is not None:
partner_domain_cond.append(
SLDomain.partner_id == partner_user.partner_id
)
if show_sl_domains:
partner_domain_cond.append(SLDomain.partner_id == None) # noqa:E711
if len(partner_domain_cond) == 1:
conditions.append(partner_domain_cond[0])
else:
return query.filter_by(premium_only=False).all()
conditions.append(or_(*partner_domain_cond))
query = Session.query(SLDomain).filter(*conditions).order_by(SLDomain.order)
return query.all()
def available_alias_domains(self) -> [str]:
"""return all domains that user can use when creating a new alias, including:
@ -2768,6 +2786,31 @@ class Notification(Base, ModelMixin):
)
class Partner(Base, ModelMixin):
__tablename__ = "partner"
name = sa.Column(sa.String(128), unique=True, nullable=False)
contact_email = sa.Column(sa.String(128), unique=True, nullable=False)
@staticmethod
def find_by_token(token: str) -> Optional[Partner]:
hmaced = PartnerApiToken.hmac_token(token)
res = (
Session.query(Partner, PartnerApiToken)
.filter(
and_(
PartnerApiToken.token == hmaced,
Partner.id == PartnerApiToken.partner_id,
)
)
.first()
)
if res:
partner, partner_api_token = res
return partner
return None
class SLDomain(Base, ModelMixin):
"""SimpleLogin domains"""
@ -2785,6 +2828,13 @@ class SLDomain(Base, ModelMixin):
sa.Boolean, nullable=False, default=False, server_default="0"
)
partner_id = sa.Column(
sa.ForeignKey(Partner.id, ondelete="cascade"),
nullable=True,
default=None,
sever_default="NULL",
)
# if enabled, do not show this domain when user creates a custom alias
hidden = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0")
@ -3231,31 +3281,6 @@ class ProviderComplaint(Base, ModelMixin):
refused_email = orm.relationship(RefusedEmail, foreign_keys=[refused_email_id])
class Partner(Base, ModelMixin):
__tablename__ = "partner"
name = sa.Column(sa.String(128), unique=True, nullable=False)
contact_email = sa.Column(sa.String(128), unique=True, nullable=False)
@staticmethod
def find_by_token(token: str) -> Optional[Partner]:
hmaced = PartnerApiToken.hmac_token(token)
res = (
Session.query(Partner, PartnerApiToken)
.filter(
and_(
PartnerApiToken.token == hmaced,
Partner.id == PartnerApiToken.partner_id,
)
)
.first()
)
if res:
partner, partner_api_token = res
return partner
return None
class PartnerApiToken(Base, ModelMixin):
__tablename__ = "partner_api_token"

View file

@ -0,0 +1,31 @@
"""empty message
Revision ID: 5f4a5625da66
Revises: 2c2093c82bc0
Create Date: 2023-04-03 18:30:46.488231
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '5f4a5625da66'
down_revision = '2c2093c82bc0'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('public_domain', sa.Column('partner_id', sa.Integer(), nullable=True, sever_default='NULL'))
op.create_foreign_key(None, 'public_domain', 'partner', ['partner_id'], ['id'], ondelete='cascade')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'public_domain', type_='foreignkey')
op.drop_column('public_domain', 'partner_id')
# ### end Alembic commands ###

106
tests/test_domains.py Normal file
View file

@ -0,0 +1,106 @@
from app.db import Session
from app.models import SLDomain, PartnerUser
from app.proton.utils import get_proton_partner
from init_app import add_sl_domains
from tests.utils import create_new_user, random_token
def setup_module():
Session.query(SLDomain).delete()
SLDomain.create(
domain="hidden", premium_only=False, flush=True, order=5, hidden=True
)
SLDomain.create(domain="free_non_partner", premium_only=False, flush=True, order=4)
SLDomain.create(
domain="premium_non_partner", premium_only=True, flush=True, order=3
)
SLDomain.create(
domain="free_partner",
premium_only=False,
flush=True,
partner_id=get_proton_partner().id,
order=2,
)
SLDomain.create(
domain="premium_partner",
premium_only=True,
flush=True,
partner_id=get_proton_partner().id,
order=1,
)
Session.commit()
def teardown_module():
Session.query(SLDomain).delete()
add_sl_domains()
def test_get_non_partner_domains():
user = create_new_user()
domains = user.get_sl_domains()
# Premium
assert len(domains) == 2
assert domains[0].domain == "premium_non_partner"
assert domains[1].domain == "free_non_partner"
# Free
user.trial_end = None
Session.flush()
domains = user.get_sl_domains()
assert len(domains) == 1
assert domains[0].domain == "free_non_partner"
def test_get_free_with_partner_domains():
user = create_new_user()
user.trial_end = None
PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
domains = user.get_sl_domains()
# Default
assert len(domains) == 1
assert domains[0].domain == "free_non_partner"
# Show partner domains
domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner())
assert len(domains) == 2
assert domains[0].domain == "free_partner"
assert domains[1].domain == "free_non_partner"
# Only partner domains
domains = user.get_sl_domains(
show_domains_for_partner=get_proton_partner(), show_sl_domains=False
)
assert len(domains) == 1
assert domains[0].domain == "free_partner"
def test_get_premium_with_partner_domains():
user = create_new_user()
PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
domains = user.get_sl_domains()
# Default
assert len(domains) == 2
assert domains[0].domain == "premium_non_partner"
assert domains[1].domain == "free_non_partner"
# Show partner domains
domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner())
assert len(domains) == 4
assert domains[0].domain == "premium_partner"
assert domains[1].domain == "free_partner"
assert domains[2].domain == "premium_non_partner"
assert domains[3].domain == "free_non_partner"
# Only partner domains
domains = user.get_sl_domains(
show_domains_for_partner=get_proton_partner(), show_sl_domains=False
)
assert len(domains) == 2
assert domains[0].domain == "premium_partner"
assert domains[1].domain == "free_partner"