Only send new alias events on user creation (#2285)

* Only send new alias events on user creation

(cherry picked from commit ab8f998dd40dc83c8f8a528a156ba50eae376aaf)

* Trigger a sync when a new partner user is created

* Improve tests

* Move it to the partner_utils

---------

Co-authored-by: Carlos Quintana <74399022+cquintana92@users.noreply.github.com>
This commit is contained in:
Adrià Casajús 2024-10-23 11:18:03 +02:00 committed by GitHub
parent f55ab58d0c
commit 9646f84b79
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 53 additions and 20 deletions

View file

@ -112,6 +112,7 @@ def ensure_partner_user_exists_for_user(
partner_email=link_request.email, partner_email=link_request.email,
external_user_id=link_request.external_user_id, external_user_id=link_request.external_user_id,
) )
Session.commit() Session.commit()
LOG.i( LOG.i(
f"Created new partner_user for partner:{partner.id} user:{sl_user.id} external_user_id:{link_request.external_user_id}. PartnerUser.id is {res.id}" f"Created new partner_user for partner:{partner.id} user:{sl_user.id} external_user_id:{link_request.external_user_id}. PartnerUser.id is {res.id}"

View file

@ -1,8 +1,10 @@
from typing import Optional from typing import Optional
import arrow
from arrow import Arrow from arrow import Arrow
from app.models import PartnerUser, PartnerSubscription, User from app import config
from app.models import PartnerUser, PartnerSubscription, User, Job
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
@ -15,6 +17,11 @@ def create_partner_user(
partner_email=partner_email, partner_email=partner_email,
external_user_id=external_user_id, external_user_id=external_user_id,
) )
Job.create(
name=config.JOB_SEND_ALIAS_CREATION_EVENTS,
payload={"user_id": user.id},
run_at=arrow.now(),
)
emit_user_audit_log( emit_user_audit_log(
user=user, user=user,
action=UserAuditLogAction.LinkAccount, action=UserAuditLogAction.LinkAccount,

View file

@ -2,11 +2,9 @@ from dataclasses import dataclass
from enum import Enum from enum import Enum
from flask import url_for from flask import url_for
from typing import Optional from typing import Optional
import arrow
from app import config
from app.errors import LinkException from app.errors import LinkException
from app.models import User, Partner, Job from app.models import User, Partner
from app.proton.proton_client import ProtonClient, ProtonUser from app.proton.proton_client import ProtonClient, ProtonUser
from app.account_linking import ( from app.account_linking import (
process_login_case, process_login_case,
@ -43,21 +41,12 @@ class ProtonCallbackHandler:
def __init__(self, proton_client: ProtonClient): def __init__(self, proton_client: ProtonClient):
self.proton_client = proton_client self.proton_client = proton_client
def _initial_alias_sync(self, user: User):
Job.create(
name=config.JOB_SEND_ALIAS_CREATION_EVENTS,
payload={"user_id": user.id},
run_at=arrow.now(),
commit=True,
)
def handle_login(self, partner: Partner) -> ProtonCallbackResult: def handle_login(self, partner: Partner) -> ProtonCallbackResult:
try: try:
user = self.__get_partner_user() user = self.__get_partner_user()
if user is None: if user is None:
return generate_account_not_allowed_to_log_in() return generate_account_not_allowed_to_log_in()
res = process_login_case(user, partner) res = process_login_case(user, partner)
self._initial_alias_sync(res.user)
return ProtonCallbackResult( return ProtonCallbackResult(
redirect_to_login=False, redirect_to_login=False,
flash_message=None, flash_message=None,
@ -86,7 +75,6 @@ class ProtonCallbackHandler:
if user is None: if user is None:
return generate_account_not_allowed_to_log_in() return generate_account_not_allowed_to_log_in()
res = process_link_case(user, current_user, partner) res = process_link_case(user, current_user, partner)
self._initial_alias_sync(res.user)
return ProtonCallbackResult( return ProtonCallbackResult(
redirect_to_login=False, redirect_to_login=False,
flash_message="Account successfully linked", flash_message="Account successfully linked",

View file

@ -25,15 +25,17 @@ class MockProtonClient(ProtonClient):
return self.user return self.user
def check_initial_sync_job(user: User): def check_initial_sync_job(user: User, expected: bool):
found = False
for job in Job.yield_per_query(10).filter_by( for job in Job.yield_per_query(10).filter_by(
name=config.JOB_SEND_ALIAS_CREATION_EVENTS, name=config.JOB_SEND_ALIAS_CREATION_EVENTS,
state=JobState.ready.value, state=JobState.ready.value,
): ):
if job.payload.get("user_id") == user.id: if job.payload.get("user_id") == user.id:
found = True
Job.delete(job.id) Job.delete(job.id)
return break
assert False assert expected == found
def test_proton_callback_handler_unexistant_sl_user(): def test_proton_callback_handler_unexistant_sl_user():
@ -69,10 +71,9 @@ def test_proton_callback_handler_unexistant_sl_user():
) )
assert partner_user is not None assert partner_user is not None
assert partner_user.external_user_id == external_id assert partner_user.external_user_id == external_id
check_initial_sync_job(res.user)
def test_proton_callback_handler_existant_sl_user(): def test_proton_callback_handler_existing_sl_user():
email = random_email() email = random_email()
sl_user = User.create(email, commit=True) sl_user = User.create(email, commit=True)
@ -98,7 +99,43 @@ def test_proton_callback_handler_existant_sl_user():
sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner().id) sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner().id)
assert sa is not None assert sa is not None
assert sa.partner_email == user.email assert sa.partner_email == user.email
check_initial_sync_job(res.user) check_initial_sync_job(res.user, True)
def test_proton_callback_handler_linked_sl_user():
email = random_email()
external_id = random_string()
sl_user = User.create(email, commit=True)
PartnerUser.create(
user_id=sl_user.id,
partner_id=get_proton_partner().id,
external_user_id=external_id,
partner_email=email,
commit=True,
)
user = UserInformation(
email=email,
name=random_string(),
id=external_id,
plan=SLPlan(type=SLPlanType.Premium, expiration=Arrow.utcnow().shift(hours=2)),
)
handler = ProtonCallbackHandler(MockProtonClient(user=user))
res = handler.handle_login(get_proton_partner())
assert res.user is not None
assert res.user.id == sl_user.id
# Ensure the user is not marked as created from partner
assert User.FLAG_CREATED_FROM_PARTNER != (
res.user.flags & User.FLAG_CREATED_FROM_PARTNER
)
assert res.user.notification is True
assert res.user.trial_end is not None
sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner().id)
assert sa is not None
assert sa.partner_email == user.email
check_initial_sync_job(res.user, False)
def test_proton_callback_handler_none_user_login(): def test_proton_callback_handler_none_user_login():