feat: add job to schedule subscription sync (#2443)

This commit is contained in:
Carlos Quintana 2025-04-30 10:08:11 +02:00 committed by Carlos Quintana
parent 8a69207828
commit 18290c840c
No known key found for this signature in database
GPG key ID: 15E73DCC410679F8
5 changed files with 280 additions and 0 deletions

View file

@ -16,3 +16,4 @@ class JobType(enum.Enum):
SEND_PROTON_WELCOME_1 = "proton-welcome-1"
SEND_ALIAS_CREATION_EVENTS = "send-alias-creation-events"
SEND_EVENT_TO_WEBHOOK = "send-event-to-webhook"
SYNC_SUBSCRIPTION = "sync-subscription"

View file

@ -0,0 +1,80 @@
from __future__ import annotations
from typing import Optional
import arrow
from app.constants import JobType
from app.errors import ProtonPartnerNotSetUp
from app.events.generated import event_pb2
from app.events.generated.event_pb2 import EventContent, UserPlanChanged
from app.models import (
User,
Job,
PartnerUser,
JobPriority,
)
from app.proton.proton_partner import get_proton_partner
from events.event_sink import EventSink
class SyncSubscriptionJob:
def __init__(self, user: User):
self._user: User = user
def run(self, sink: EventSink) -> bool:
# Check if the current user has a partner_id
try:
proton_partner_id = get_proton_partner().id
except ProtonPartnerNotSetUp:
return False
# It has. Retrieve the information for the PartnerUser
partner_user = PartnerUser.get_by(
user_id=self._user.id, partner_id=proton_partner_id
)
if partner_user is None:
return True
if self._user.lifetime:
content = UserPlanChanged(lifetime=True)
else:
plan_end = self._user.get_active_subscription_end(
include_partner_subscription=False
)
if plan_end:
content = UserPlanChanged(plan_end_time=plan_end.timestamp)
else:
content = UserPlanChanged()
event = event_pb2.Event(
user_id=self._user.id,
external_user_id=partner_user.external_user_id,
partner_id=partner_user.partner_id,
content=EventContent(user_plan_change=content),
)
serialized = event.SerializeToString()
return sink.send_data_to_webhook(serialized)
@staticmethod
def create_from_job(job: Job) -> Optional[SyncSubscriptionJob]:
user = User.get(job.payload["user_id"])
if not user:
return None
return SyncSubscriptionJob(user=user)
def store_job_in_db(
self,
run_at: Optional[arrow.Arrow],
priority: JobPriority = JobPriority.Default,
commit: bool = True,
) -> Job:
return Job.create(
name=JobType.SYNC_SUBSCRIPTION.value,
payload={"user_id": self._user.id},
priority=priority,
run_at=run_at if run_at is not None else arrow.now(),
commit=commit,
)

View file

@ -24,6 +24,7 @@ from app.import_utils import handle_batch_import
from app.jobs.event_jobs import send_alias_creation_events_for_user
from app.jobs.export_user_data_job import ExportUserDataJob
from app.jobs.send_event_job import SendEventToWebhookJob
from app.jobs.sync_subscription_job import SyncSubscriptionJob
from app.log import LOG
from app.models import User, Job, BatchImport, Mailbox, CustomDomain, JobState
from app.monitor_utils import send_version_event
@ -317,6 +318,10 @@ def process_job(job: Job):
send_job = SendEventToWebhookJob.create_from_job(job)
if send_job:
send_job.run(HttpEventSink())
elif job.name == JobType.SYNC_SUBSCRIPTION.value:
sync_job = SyncSubscriptionJob.create_from_job(job)
if sync_job:
sync_job.run(HttpEventSink())
else:
LOG.e("Unknown job name %s", job.name)

View file

@ -0,0 +1,104 @@
#!/usr/bin/env python3
import argparse
import sys
import time
from sqlalchemy import func
from typing import Optional
from app.db import Session
from app.jobs.sync_subscription_job import SyncSubscriptionJob
from app.models import PartnerUser, User, JobPriority
def process(start_pu_id: int, end_pu_id: int, step: int, only_lifetime: bool):
print(
f"Checking partner user {start_pu_id} to {end_pu_id} (step={step}) (only_lifetime={only_lifetime})"
)
start_time = time.time()
processed = 0
for batch_start in range(start_pu_id, end_pu_id, step):
batch_end = min(batch_start + step, end_pu_id)
query = (
Session.query(User)
.join(PartnerUser, PartnerUser.user_id == User.id)
.filter(PartnerUser.id >= batch_start, PartnerUser.id < batch_end)
)
if only_lifetime:
query = query.filter(
User.lifetime == True, # noqa :E712
)
users = query.all()
for user in users:
job = SyncSubscriptionJob(user)
job.store_job_in_db(priority=JobPriority.Low, run_at=None, commit=False)
processed += 1
Session.flush()
Session.commit()
elapsed = time.time() - start_time
time_per_user = elapsed / processed
remaining = end_pu_id - batch_end
mins_remaining = remaining / time_per_user
print(
f"PartnerUser {batch_start}/{end_pu_id} | processed = {processed} | {mins_remaining:.2f} mins remaining"
)
def main():
parser = argparse.ArgumentParser(
prog="Schedule Sync User Jobs",
description="Create jobs to sync user subscriptions",
)
parser.add_argument(
"-s", "--start_pu_id", default=0, type=int, help="Initial partner_user_id"
)
parser.add_argument(
"-e", "--end_pu_id", default=0, type=int, help="Last partner_user_id"
)
parser.add_argument("-t", "--step", default=10000, type=int, help="Step to use")
parser.add_argument("-u", "--user", default="", type=str, help="User to sync")
parser.add_argument(
"-l", "--lifetime", action="store_true", help="Only sync lifetime users"
)
args = parser.parse_args()
start_pu_id = args.start_pu_id
end_pu_id = args.end_pu_id
user_id = args.user
only_lifetime = args.lifetime
step = args.step
if not end_pu_id:
end_pu_id = Session.query(func.max(PartnerUser.id)).scalar()
if user_id:
try:
user_id = int(user_id)
except ValueError:
user = User.get_by(email=user_id)
if not user:
print(f"User {user_id} not found")
sys.exit(1)
user_id = user.id
print(f"Limiting to user {user_id}")
partner_user: Optional[PartnerUser] = PartnerUser.get_by(user_id=user_id)
if not partner_user:
print(f"Could not find PartnerUser for user_id={user_id}")
sys.exit(1)
# So we only have one loop
step = 1
start_pu_id = partner_user.id
end_pu_id = partner_user.id + 1 # Necessary to at least have 1 result
process(
start_pu_id=start_pu_id,
end_pu_id=end_pu_id,
step=step,
only_lifetime=only_lifetime,
)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,90 @@
import arrow
from typing import List
from app.constants import JobType
from app.events.generated import event_pb2
from app.jobs.sync_subscription_job import SyncSubscriptionJob
from app.models import PartnerUser, JobPriority, SyncEvent
from app.proton.proton_partner import get_proton_partner
from events.event_sink import EventSink
from tests.utils import create_new_user, random_token
class InMemorySink(EventSink):
def __init__(self):
self.events = []
def process(self, event: SyncEvent) -> bool:
raise RuntimeError("Should not be called")
def send_data_to_webhook(self, data: bytes) -> bool:
self.events.append(data)
return True
def test_serialize_and_deserialize_job():
user = create_new_user()
run_at = arrow.now().shift(hours=10)
priority = JobPriority.High
db_job = SyncSubscriptionJob(user).store_job_in_db(run_at=run_at, priority=priority)
assert db_job.run_at == run_at
assert db_job.priority == priority
assert db_job.name == JobType.SYNC_SUBSCRIPTION.value
job = SyncSubscriptionJob.create_from_job(db_job)
assert job._user.id == user.id
def _run_send_event_test(partner_user: PartnerUser) -> event_pb2.Event:
job = SyncSubscriptionJob(partner_user.user)
sink = InMemorySink()
assert job.run(sink)
sent_events: List[bytes] = sink.events
assert len(sent_events) == 1
decoded = event_pb2.Event()
decoded.ParseFromString(sent_events[0])
return decoded
def test_send_event_to_webhook_free():
user = create_new_user()
external_user_id = random_token(10)
partner_user = PartnerUser.create(
user_id=user.id,
partner_id=get_proton_partner().id,
external_user_id=external_user_id,
flush=True,
)
res = _run_send_event_test(partner_user)
assert res.user_id == user.id
assert res.partner_id == partner_user.partner_id
assert res.external_user_id == external_user_id
assert res.content == event_pb2.EventContent(
user_plan_change=event_pb2.UserPlanChanged()
)
def test_send_event_to_webhook_lifetime():
user = create_new_user()
user.lifetime = True
external_user_id = random_token(10)
partner_user = PartnerUser.create(
user_id=user.id,
partner_id=get_proton_partner().id,
external_user_id=external_user_id,
commit=True,
)
res = _run_send_event_test(partner_user)
assert res.user_id == user.id
assert res.partner_id == partner_user.partner_id
assert res.external_user_id == external_user_id
assert res.content == event_pb2.EventContent(
user_plan_change=event_pb2.UserPlanChanged(lifetime=True)
)