mirror of
https://github.com/simple-login/app.git
synced 2025-09-17 12:05:52 +08:00
feat: add job to schedule subscription sync (#2443)
This commit is contained in:
parent
8a69207828
commit
18290c840c
5 changed files with 280 additions and 0 deletions
|
@ -16,3 +16,4 @@ class JobType(enum.Enum):
|
|||
SEND_PROTON_WELCOME_1 = "proton-welcome-1"
|
||||
SEND_ALIAS_CREATION_EVENTS = "send-alias-creation-events"
|
||||
SEND_EVENT_TO_WEBHOOK = "send-event-to-webhook"
|
||||
SYNC_SUBSCRIPTION = "sync-subscription"
|
||||
|
|
80
app/jobs/sync_subscription_job.py
Normal file
80
app/jobs/sync_subscription_job.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import arrow
|
||||
|
||||
from app.constants import JobType
|
||||
from app.errors import ProtonPartnerNotSetUp
|
||||
from app.events.generated import event_pb2
|
||||
from app.events.generated.event_pb2 import EventContent, UserPlanChanged
|
||||
from app.models import (
|
||||
User,
|
||||
Job,
|
||||
PartnerUser,
|
||||
JobPriority,
|
||||
)
|
||||
from app.proton.proton_partner import get_proton_partner
|
||||
from events.event_sink import EventSink
|
||||
|
||||
|
||||
class SyncSubscriptionJob:
|
||||
def __init__(self, user: User):
|
||||
self._user: User = user
|
||||
|
||||
def run(self, sink: EventSink) -> bool:
|
||||
# Check if the current user has a partner_id
|
||||
try:
|
||||
proton_partner_id = get_proton_partner().id
|
||||
except ProtonPartnerNotSetUp:
|
||||
return False
|
||||
|
||||
# It has. Retrieve the information for the PartnerUser
|
||||
partner_user = PartnerUser.get_by(
|
||||
user_id=self._user.id, partner_id=proton_partner_id
|
||||
)
|
||||
if partner_user is None:
|
||||
return True
|
||||
|
||||
if self._user.lifetime:
|
||||
content = UserPlanChanged(lifetime=True)
|
||||
else:
|
||||
plan_end = self._user.get_active_subscription_end(
|
||||
include_partner_subscription=False
|
||||
)
|
||||
if plan_end:
|
||||
content = UserPlanChanged(plan_end_time=plan_end.timestamp)
|
||||
else:
|
||||
content = UserPlanChanged()
|
||||
|
||||
event = event_pb2.Event(
|
||||
user_id=self._user.id,
|
||||
external_user_id=partner_user.external_user_id,
|
||||
partner_id=partner_user.partner_id,
|
||||
content=EventContent(user_plan_change=content),
|
||||
)
|
||||
|
||||
serialized = event.SerializeToString()
|
||||
return sink.send_data_to_webhook(serialized)
|
||||
|
||||
@staticmethod
|
||||
def create_from_job(job: Job) -> Optional[SyncSubscriptionJob]:
|
||||
user = User.get(job.payload["user_id"])
|
||||
if not user:
|
||||
return None
|
||||
|
||||
return SyncSubscriptionJob(user=user)
|
||||
|
||||
def store_job_in_db(
|
||||
self,
|
||||
run_at: Optional[arrow.Arrow],
|
||||
priority: JobPriority = JobPriority.Default,
|
||||
commit: bool = True,
|
||||
) -> Job:
|
||||
return Job.create(
|
||||
name=JobType.SYNC_SUBSCRIPTION.value,
|
||||
payload={"user_id": self._user.id},
|
||||
priority=priority,
|
||||
run_at=run_at if run_at is not None else arrow.now(),
|
||||
commit=commit,
|
||||
)
|
|
@ -24,6 +24,7 @@ from app.import_utils import handle_batch_import
|
|||
from app.jobs.event_jobs import send_alias_creation_events_for_user
|
||||
from app.jobs.export_user_data_job import ExportUserDataJob
|
||||
from app.jobs.send_event_job import SendEventToWebhookJob
|
||||
from app.jobs.sync_subscription_job import SyncSubscriptionJob
|
||||
from app.log import LOG
|
||||
from app.models import User, Job, BatchImport, Mailbox, CustomDomain, JobState
|
||||
from app.monitor_utils import send_version_event
|
||||
|
@ -317,6 +318,10 @@ def process_job(job: Job):
|
|||
send_job = SendEventToWebhookJob.create_from_job(job)
|
||||
if send_job:
|
||||
send_job.run(HttpEventSink())
|
||||
elif job.name == JobType.SYNC_SUBSCRIPTION.value:
|
||||
sync_job = SyncSubscriptionJob.create_from_job(job)
|
||||
if sync_job:
|
||||
sync_job.run(HttpEventSink())
|
||||
else:
|
||||
LOG.e("Unknown job name %s", job.name)
|
||||
|
||||
|
|
104
oneshot/schedule_sync_subscription_job.py
Normal file
104
oneshot/schedule_sync_subscription_job.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
|
||||
from sqlalchemy import func
|
||||
from typing import Optional
|
||||
|
||||
from app.db import Session
|
||||
from app.jobs.sync_subscription_job import SyncSubscriptionJob
|
||||
from app.models import PartnerUser, User, JobPriority
|
||||
|
||||
|
||||
def process(start_pu_id: int, end_pu_id: int, step: int, only_lifetime: bool):
|
||||
print(
|
||||
f"Checking partner user {start_pu_id} to {end_pu_id} (step={step}) (only_lifetime={only_lifetime})"
|
||||
)
|
||||
start_time = time.time()
|
||||
processed = 0
|
||||
for batch_start in range(start_pu_id, end_pu_id, step):
|
||||
batch_end = min(batch_start + step, end_pu_id)
|
||||
query = (
|
||||
Session.query(User)
|
||||
.join(PartnerUser, PartnerUser.user_id == User.id)
|
||||
.filter(PartnerUser.id >= batch_start, PartnerUser.id < batch_end)
|
||||
)
|
||||
if only_lifetime:
|
||||
query = query.filter(
|
||||
User.lifetime == True, # noqa :E712
|
||||
)
|
||||
users = query.all()
|
||||
for user in users:
|
||||
job = SyncSubscriptionJob(user)
|
||||
job.store_job_in_db(priority=JobPriority.Low, run_at=None, commit=False)
|
||||
processed += 1
|
||||
Session.flush()
|
||||
Session.commit()
|
||||
elapsed = time.time() - start_time
|
||||
time_per_user = elapsed / processed
|
||||
remaining = end_pu_id - batch_end
|
||||
mins_remaining = remaining / time_per_user
|
||||
print(
|
||||
f"PartnerUser {batch_start}/{end_pu_id} | processed = {processed} | {mins_remaining:.2f} mins remaining"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="Schedule Sync User Jobs",
|
||||
description="Create jobs to sync user subscriptions",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s", "--start_pu_id", default=0, type=int, help="Initial partner_user_id"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e", "--end_pu_id", default=0, type=int, help="Last partner_user_id"
|
||||
)
|
||||
parser.add_argument("-t", "--step", default=10000, type=int, help="Step to use")
|
||||
parser.add_argument("-u", "--user", default="", type=str, help="User to sync")
|
||||
parser.add_argument(
|
||||
"-l", "--lifetime", action="store_true", help="Only sync lifetime users"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
start_pu_id = args.start_pu_id
|
||||
end_pu_id = args.end_pu_id
|
||||
user_id = args.user
|
||||
only_lifetime = args.lifetime
|
||||
step = args.step
|
||||
|
||||
if not end_pu_id:
|
||||
end_pu_id = Session.query(func.max(PartnerUser.id)).scalar()
|
||||
|
||||
if user_id:
|
||||
try:
|
||||
user_id = int(user_id)
|
||||
except ValueError:
|
||||
user = User.get_by(email=user_id)
|
||||
if not user:
|
||||
print(f"User {user_id} not found")
|
||||
sys.exit(1)
|
||||
user_id = user.id
|
||||
print(f"Limiting to user {user_id}")
|
||||
partner_user: Optional[PartnerUser] = PartnerUser.get_by(user_id=user_id)
|
||||
if not partner_user:
|
||||
print(f"Could not find PartnerUser for user_id={user_id}")
|
||||
sys.exit(1)
|
||||
|
||||
# So we only have one loop
|
||||
step = 1
|
||||
start_pu_id = partner_user.id
|
||||
end_pu_id = partner_user.id + 1 # Necessary to at least have 1 result
|
||||
|
||||
process(
|
||||
start_pu_id=start_pu_id,
|
||||
end_pu_id=end_pu_id,
|
||||
step=step,
|
||||
only_lifetime=only_lifetime,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
90
tests/jobs/test_sync_subscription_job.py
Normal file
90
tests/jobs/test_sync_subscription_job.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
import arrow
|
||||
from typing import List
|
||||
|
||||
from app.constants import JobType
|
||||
from app.events.generated import event_pb2
|
||||
from app.jobs.sync_subscription_job import SyncSubscriptionJob
|
||||
from app.models import PartnerUser, JobPriority, SyncEvent
|
||||
from app.proton.proton_partner import get_proton_partner
|
||||
from events.event_sink import EventSink
|
||||
from tests.utils import create_new_user, random_token
|
||||
|
||||
|
||||
class InMemorySink(EventSink):
|
||||
def __init__(self):
|
||||
self.events = []
|
||||
|
||||
def process(self, event: SyncEvent) -> bool:
|
||||
raise RuntimeError("Should not be called")
|
||||
|
||||
def send_data_to_webhook(self, data: bytes) -> bool:
|
||||
self.events.append(data)
|
||||
return True
|
||||
|
||||
|
||||
def test_serialize_and_deserialize_job():
|
||||
user = create_new_user()
|
||||
run_at = arrow.now().shift(hours=10)
|
||||
priority = JobPriority.High
|
||||
db_job = SyncSubscriptionJob(user).store_job_in_db(run_at=run_at, priority=priority)
|
||||
assert db_job.run_at == run_at
|
||||
assert db_job.priority == priority
|
||||
assert db_job.name == JobType.SYNC_SUBSCRIPTION.value
|
||||
|
||||
job = SyncSubscriptionJob.create_from_job(db_job)
|
||||
assert job._user.id == user.id
|
||||
|
||||
|
||||
def _run_send_event_test(partner_user: PartnerUser) -> event_pb2.Event:
|
||||
job = SyncSubscriptionJob(partner_user.user)
|
||||
sink = InMemorySink()
|
||||
assert job.run(sink)
|
||||
|
||||
sent_events: List[bytes] = sink.events
|
||||
assert len(sent_events) == 1
|
||||
|
||||
decoded = event_pb2.Event()
|
||||
decoded.ParseFromString(sent_events[0])
|
||||
|
||||
return decoded
|
||||
|
||||
|
||||
def test_send_event_to_webhook_free():
|
||||
user = create_new_user()
|
||||
external_user_id = random_token(10)
|
||||
partner_user = PartnerUser.create(
|
||||
user_id=user.id,
|
||||
partner_id=get_proton_partner().id,
|
||||
external_user_id=external_user_id,
|
||||
flush=True,
|
||||
)
|
||||
|
||||
res = _run_send_event_test(partner_user)
|
||||
|
||||
assert res.user_id == user.id
|
||||
assert res.partner_id == partner_user.partner_id
|
||||
assert res.external_user_id == external_user_id
|
||||
assert res.content == event_pb2.EventContent(
|
||||
user_plan_change=event_pb2.UserPlanChanged()
|
||||
)
|
||||
|
||||
|
||||
def test_send_event_to_webhook_lifetime():
|
||||
user = create_new_user()
|
||||
user.lifetime = True
|
||||
external_user_id = random_token(10)
|
||||
partner_user = PartnerUser.create(
|
||||
user_id=user.id,
|
||||
partner_id=get_proton_partner().id,
|
||||
external_user_id=external_user_id,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
res = _run_send_event_test(partner_user)
|
||||
|
||||
assert res.user_id == user.id
|
||||
assert res.partner_id == partner_user.partner_id
|
||||
assert res.external_user_id == external_user_id
|
||||
assert res.content == event_pb2.EventContent(
|
||||
user_plan_change=event_pb2.UserPlanChanged(lifetime=True)
|
||||
)
|
Loading…
Add table
Reference in a new issue