diff --git a/app/account_linking.py b/app/account_linking.py index 02160e64..56b1da83 100644 --- a/app/account_linking.py +++ b/app/account_linking.py @@ -3,12 +3,15 @@ from dataclasses import dataclass from enum import Enum from typing import Optional +import arrow from arrow import Arrow from newrelic import agent from sqlalchemy import or_ from app.db import Session from app.email_utils import send_welcome_email +from app.events.event_dispatcher import EventDispatcher +from app.events.generated.event_pb2 import UserPlanChanged, EventContent from app.partner_user_utils import create_partner_user, create_partner_subscription from app.utils import sanitize_email, canonicalize_email from app.errors import ( @@ -54,6 +57,21 @@ class LinkResult: strategy: str +def send_user_plan_changed_event(partner_user: PartnerUser) -> Optional[int]: + subscription_end = partner_user.user.get_active_subscription_end( + include_partner_subscription=False + ) + end_timestamp = None + if partner_user.user.lifetime: + end_timestamp = arrow.get("2038-01-01").timestamp + elif subscription_end: + end_timestamp = subscription_end.timestamp + event = UserPlanChanged(plan_end_time=end_timestamp) + EventDispatcher.send_event(partner_user.user, EventContent(user_plan_change=event)) + Session.flush() + return end_timestamp + + def set_plan_for_partner_user(partner_user: PartnerUser, plan: SLPlan): sub = PartnerSubscription.get_by(partner_user_id=partner_user.id) if plan.type == SLPlanType.Free: @@ -88,6 +106,8 @@ def set_plan_for_partner_user(partner_user: PartnerUser, plan: SLPlan): action=UserAuditLogAction.SubscriptionExtended, message="Extended partner subscription", ) + Session.flush() + send_user_plan_changed_event(partner_user) Session.commit() diff --git a/oneshot/send_plan_change_events.py b/oneshot/send_plan_change_events.py index 8db1870a..50e38a7b 100644 --- a/oneshot/send_plan_change_events.py +++ b/oneshot/send_plan_change_events.py @@ -5,8 +5,7 @@ import time import arrow from sqlalchemy import func -from app.events.event_dispatcher import EventDispatcher -from app.events.generated.event_pb2 import UserPlanChanged, EventContent +from app.account_linking import send_user_plan_changed_event from app.models import PartnerUser from app.db import Session @@ -39,21 +38,12 @@ for batch_start in range(pu_id_start, max_pu_id, step): ) ).all() for partner_user in partner_users: - subscription_end = partner_user.user.get_active_subscription_end( - include_partner_subscription=False - ) - end_timestamp = None - if partner_user.user.lifetime: - with_lifetime += 1 - end_timestamp = arrow.get("2038-01-01").timestamp - elif subscription_end: - with_premium += 1 - end_timestamp = subscription_end.timestamp - event = UserPlanChanged(plan_end_time=end_timestamp) - EventDispatcher.send_event( - partner_user.user, EventContent(user_plan_change=event) - ) - Session.flush() + subscription_end = send_user_plan_changed_event(partner_user) + if subscription_end is not None: + if subscription_end > arrow.get("2038-01-01").timestamp: + with_lifetime += 1 + else: + with_premium += 1 updated += 1 Session.commit() elapsed = time.time() - start_time