diff --git a/app/constants.py b/app/constants.py
index a2b787f7..48f44340 100644
--- a/app/constants.py
+++ b/app/constants.py
@@ -16,3 +16,4 @@ class JobType(enum.Enum):
     SEND_PROTON_WELCOME_1 = "proton-welcome-1"
     SEND_ALIAS_CREATION_EVENTS = "send-alias-creation-events"
     SEND_EVENT_TO_WEBHOOK = "send-event-to-webhook"
+    SYNC_SUBSCRIPTION = "sync-subscription"
diff --git a/app/jobs/sync_subscription_job.py b/app/jobs/sync_subscription_job.py
new file mode 100644
index 00000000..c42f467d
--- /dev/null
+++ b/app/jobs/sync_subscription_job.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import arrow
+
+from app.constants import JobType
+from app.errors import ProtonPartnerNotSetUp
+from app.events.generated import event_pb2
+from app.events.generated.event_pb2 import EventContent, UserPlanChanged
+from app.models import (
+    User,
+    Job,
+    PartnerUser,
+    JobPriority,
+)
+from app.proton.proton_partner import get_proton_partner
+from events.event_sink import EventSink
+
+
+class SyncSubscriptionJob:
+    def __init__(self, user: User):
+        self._user: User = user
+
+    def run(self, sink: EventSink) -> bool:
+        # Check if the current user has a partner_id
+        try:
+            proton_partner_id = get_proton_partner().id
+        except ProtonPartnerNotSetUp:
+            return False
+
+        # It has. Retrieve the information for the PartnerUser
+        partner_user = PartnerUser.get_by(
+            user_id=self._user.id, partner_id=proton_partner_id
+        )
+        if partner_user is None:
+            return True
+
+        if self._user.lifetime:
+            content = UserPlanChanged(lifetime=True)
+        else:
+            plan_end = self._user.get_active_subscription_end(
+                include_partner_subscription=False
+            )
+            if plan_end:
+                content = UserPlanChanged(plan_end_time=plan_end.timestamp)
+            else:
+                content = UserPlanChanged()
+
+        event = event_pb2.Event(
+            user_id=self._user.id,
+            external_user_id=partner_user.external_user_id,
+            partner_id=partner_user.partner_id,
+            content=EventContent(user_plan_change=content),
+        )
+
+        serialized = event.SerializeToString()
+        return sink.send_data_to_webhook(serialized)
+
+    @staticmethod
+    def create_from_job(job: Job) -> Optional[SyncSubscriptionJob]:
+        user = User.get(job.payload["user_id"])
+        if not user:
+            return None
+
+        return SyncSubscriptionJob(user=user)
+
+    def store_job_in_db(
+        self,
+        run_at: Optional[arrow.Arrow],
+        priority: JobPriority = JobPriority.Default,
+        commit: bool = True,
+    ) -> Job:
+        return Job.create(
+            name=JobType.SYNC_SUBSCRIPTION.value,
+            payload={"user_id": self._user.id},
+            priority=priority,
+            run_at=run_at if run_at is not None else arrow.now(),
+            commit=commit,
+        )
diff --git a/job_runner.py b/job_runner.py
index 23ed4087..3d77f55d 100644
--- a/job_runner.py
+++ b/job_runner.py
@@ -24,6 +24,7 @@ from app.import_utils import handle_batch_import
 from app.jobs.event_jobs import send_alias_creation_events_for_user
 from app.jobs.export_user_data_job import ExportUserDataJob
 from app.jobs.send_event_job import SendEventToWebhookJob
+from app.jobs.sync_subscription_job import SyncSubscriptionJob
 from app.log import LOG
 from app.models import User, Job, BatchImport, Mailbox, CustomDomain, JobState
 from app.monitor_utils import send_version_event
@@ -317,6 +318,10 @@ def process_job(job: Job):
         send_job = SendEventToWebhookJob.create_from_job(job)
         if send_job:
             send_job.run(HttpEventSink())
+    elif job.name == JobType.SYNC_SUBSCRIPTION.value:
+        sync_job = SyncSubscriptionJob.create_from_job(job)
+        if sync_job:
+            sync_job.run(HttpEventSink())
     else:
         LOG.e("Unknown job name %s", job.name)
 
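Note (not part of the diff): a minimal sketch of how the new job is expected to be enqueued, using only the APIs introduced above; the user id is illustrative.

import arrow

from app.jobs.sync_subscription_job import SyncSubscriptionJob
from app.models import JobPriority, User

user = User.get(1234)  # illustrative user id
if user:
    # Persist a sync-subscription job; job_runner.process_job() later rebuilds
    # it via SyncSubscriptionJob.create_from_job() and runs it against
    # HttpEventSink(), as wired up in the new elif branch above.
    SyncSubscriptionJob(user).store_job_in_db(
        run_at=arrow.now(),  # picked up on the next runner poll
        priority=JobPriority.Default,
    )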
diff --git a/oneshot/schedule_sync_subscription_job.py b/oneshot/schedule_sync_subscription_job.py
new file mode 100644
index 00000000..de8179ac
--- /dev/null
+++ b/oneshot/schedule_sync_subscription_job.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+
+import argparse
+import sys
+import time
+
+from sqlalchemy import func
+from typing import Optional
+
+from app.db import Session
+from app.jobs.sync_subscription_job import SyncSubscriptionJob
+from app.models import PartnerUser, User, JobPriority
+
+
+def process(start_pu_id: int, end_pu_id: int, step: int, only_lifetime: bool):
+    print(
+        f"Checking partner user {start_pu_id} to {end_pu_id} (step={step}) (only_lifetime={only_lifetime})"
+    )
+    start_time = time.time()
+    processed = 0
+    for batch_start in range(start_pu_id, end_pu_id, step):
+        batch_end = min(batch_start + step, end_pu_id)
+        query = (
+            Session.query(User)
+            .join(PartnerUser, PartnerUser.user_id == User.id)
+            .filter(PartnerUser.id >= batch_start, PartnerUser.id < batch_end)
+        )
+        if only_lifetime:
+            query = query.filter(
+                User.lifetime == True,  # noqa :E712
+            )
+        users = query.all()
+        for user in users:
+            job = SyncSubscriptionJob(user)
+            job.store_job_in_db(priority=JobPriority.Low, run_at=None, commit=False)
+            processed += 1
+        Session.flush()
+        Session.commit()
+        elapsed = time.time() - start_time
+        time_per_user = elapsed / max(processed, 1)
+        remaining = end_pu_id - batch_end
+        mins_remaining = remaining * time_per_user / 60
+        print(
+            f"PartnerUser {batch_start}/{end_pu_id} | processed = {processed} | {mins_remaining:.2f} mins remaining"
+        )
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        prog="Schedule Sync User Jobs",
+        description="Create jobs to sync user subscriptions",
+    )
+    parser.add_argument(
+        "-s", "--start_pu_id", default=0, type=int, help="Initial partner_user_id"
+    )
+    parser.add_argument(
+        "-e", "--end_pu_id", default=0, type=int, help="Last partner_user_id"
+    )
+    parser.add_argument("-t", "--step", default=10000, type=int, help="Step to use")
+    parser.add_argument("-u", "--user", default="", type=str, help="User to sync")
+    parser.add_argument(
+        "-l", "--lifetime", action="store_true", help="Only sync lifetime users"
+    )
+
+    args = parser.parse_args()
+    start_pu_id = args.start_pu_id
+    end_pu_id = args.end_pu_id
+    user_id = args.user
+    only_lifetime = args.lifetime
+    step = args.step
+
+    if not end_pu_id:
+        end_pu_id = Session.query(func.max(PartnerUser.id)).scalar()
+
+    if user_id:
+        try:
+            user_id = int(user_id)
+        except ValueError:
+            user = User.get_by(email=user_id)
+            if not user:
+                print(f"User {user_id} not found")
+                sys.exit(1)
+            user_id = user.id
+        print(f"Limiting to user {user_id}")
+        partner_user: Optional[PartnerUser] = PartnerUser.get_by(user_id=user_id)
+        if not partner_user:
+            print(f"Could not find PartnerUser for user_id={user_id}")
+            sys.exit(1)
+
+        # So we only have one loop
+        step = 1
+        start_pu_id = partner_user.id
+        end_pu_id = partner_user.id + 1  # Necessary to at least have 1 result
+
+    process(
+        start_pu_id=start_pu_id,
+        end_pu_id=end_pu_id,
+        step=step,
+        only_lifetime=only_lifetime,
+    )
+
+
+if __name__ == "__main__":
+    main()
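Usage note (not part of the diff): the script is meant to be run directly from the repository root, for example "python oneshot/schedule_sync_subscription_job.py -l" to queue low-priority sync jobs for every lifetime partner user in batches, or "python oneshot/schedule_sync_subscription_job.py -u user@example.com" to queue a single job (the email address is illustrative; -u also accepts a numeric user id). With -u the range is narrowed to exactly one PartnerUser and step is forced to 1, so the loop runs a single batch.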
diff --git a/tests/jobs/test_sync_subscription_job.py b/tests/jobs/test_sync_subscription_job.py
new file mode 100644
index 00000000..d5bebbec
--- /dev/null
+++ b/tests/jobs/test_sync_subscription_job.py
@@ -0,0 +1,90 @@
+import arrow
+from typing import List
+
+from app.constants import JobType
+from app.events.generated import event_pb2
+from app.jobs.sync_subscription_job import SyncSubscriptionJob
+from app.models import PartnerUser, JobPriority, SyncEvent
+from app.proton.proton_partner import get_proton_partner
+from events.event_sink import EventSink
+from tests.utils import create_new_user, random_token
+
+
+class InMemorySink(EventSink):
+    def __init__(self):
+        self.events = []
+
+    def process(self, event: SyncEvent) -> bool:
+        raise RuntimeError("Should not be called")
+
+    def send_data_to_webhook(self, data: bytes) -> bool:
+        self.events.append(data)
+        return True
+
+
+def test_serialize_and_deserialize_job():
+    user = create_new_user()
+    run_at = arrow.now().shift(hours=10)
+    priority = JobPriority.High
+    db_job = SyncSubscriptionJob(user).store_job_in_db(run_at=run_at, priority=priority)
+    assert db_job.run_at == run_at
+    assert db_job.priority == priority
+    assert db_job.name == JobType.SYNC_SUBSCRIPTION.value
+
+    job = SyncSubscriptionJob.create_from_job(db_job)
+    assert job._user.id == user.id
+
+
+def _run_send_event_test(partner_user: PartnerUser) -> event_pb2.Event:
+    job = SyncSubscriptionJob(partner_user.user)
+    sink = InMemorySink()
+    assert job.run(sink)
+
+    sent_events: List[bytes] = sink.events
+    assert len(sent_events) == 1
+
+    decoded = event_pb2.Event()
+    decoded.ParseFromString(sent_events[0])
+
+    return decoded
+
+
+def test_send_event_to_webhook_free():
+    user = create_new_user()
+    external_user_id = random_token(10)
+    partner_user = PartnerUser.create(
+        user_id=user.id,
+        partner_id=get_proton_partner().id,
+        external_user_id=external_user_id,
+        flush=True,
+    )
+
+    res = _run_send_event_test(partner_user)
+
+    assert res.user_id == user.id
+    assert res.partner_id == partner_user.partner_id
+    assert res.external_user_id == external_user_id
+    assert res.content == event_pb2.EventContent(
+        user_plan_change=event_pb2.UserPlanChanged()
+    )
+
+
+def test_send_event_to_webhook_lifetime():
+    user = create_new_user()
+    user.lifetime = True
+    external_user_id = random_token(10)
+    partner_user = PartnerUser.create(
+        user_id=user.id,
+        partner_id=get_proton_partner().id,
+        external_user_id=external_user_id,
+        commit=True,
+    )
+
+    res = _run_send_event_test(partner_user)
+
+    assert res.user_id == user.id
+    assert res.partner_id == partner_user.partner_id
+    assert res.external_user_id == external_user_id
+    assert res.content == event_pb2.EventContent(
+        user_plan_change=event_pb2.UserPlanChanged(lifetime=True)
+    )
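Note (not part of the diff): the tests above parse the serialized payload with event_pb2, and a receiving webhook can do the same. A minimal sketch, using only the fields exercised in this diff (decode_plan_change is a hypothetical helper name):

from app.events.generated import event_pb2


def decode_plan_change(payload: bytes) -> str:
    # Parse the bytes exactly as _run_send_event_test does above.
    event = event_pb2.Event()
    event.ParseFromString(payload)

    change = event.content.user_plan_change
    if change.lifetime:
        return f"user {event.external_user_id}: lifetime plan"
    # proto3 default: plan_end_time stays 0 when no paid plan end is set.
    return f"user {event.external_user_id}: plan_end_time={change.plan_end_time}"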