From deafb529fae5c46001dbd23bccf670d188277610 Mon Sep 17 00:00:00 2001
From: Bohdan Shtepan
Date: Mon, 19 May 2025 14:59:50 +0200
Subject: [PATCH] feat: added abuser utils for main abuser (un)archiving flows.
 (#2468)

* feat: added abuser utils for main abuser (un)archiving flows.

* feat: update the abuser utils to use external secrets in conjunction with HMAC; add abuser audit log.

* fix: revert the DB_URI part.

* feat: address feedback.

* feat: address feedback.

* feat: address feedback.

* feat: address feedback.

* feat: address feedback.
---
 app/abuser_audit_log_utils.py                 |  26 +
 app/abuser_utils.py                           | 294 +++++++++
 app/admin_model.py                            |   9 +-
 app/api/views/auth.py                         |   7 +
 app/auth/views/register.py                    |   5 +-
 app/config.py                                 |  15 +
 app/constants.py                              |   2 +
 app/models.py                                 |  31 +
 .../versions/2025_051618_e38002759d8f_.py     |  68 +++
 tests/test.env                                |   1 -
 tests/test_abuser_audit_log_utils.py          |  27 +
 tests/test_abuser_utils.py                    | 566 ++++++++++++++++++
 12 files changed, 1047 insertions(+), 4 deletions(-)
 create mode 100644 app/abuser_audit_log_utils.py
 create mode 100644 app/abuser_utils.py
 create mode 100644 migrations/versions/2025_051618_e38002759d8f_.py
 create mode 100644 tests/test_abuser_audit_log_utils.py
 create mode 100644 tests/test_abuser_utils.py

diff --git a/app/abuser_audit_log_utils.py b/app/abuser_audit_log_utils.py
new file mode 100644
index 00000000..37436abe
--- /dev/null
+++ b/app/abuser_audit_log_utils.py
@@ -0,0 +1,26 @@
+from enum import Enum
+from typing import Optional
+
+from app.models import AbuserAuditLog
+
+
+class AbuserAuditLogAction(Enum):
+    MarkAbuser = "mark_abuser"
+    UnmarkAbuser = "unmark_abuser"
+    GetAbuserBundles = "get_abuser_bundles"
+
+
+def emit_abuser_audit_log(
+    user_id: int,
+    action: AbuserAuditLogAction,
+    message: str,
+    admin_id: Optional[int] = None,
+    commit: bool = False,
+) -> None:
+    AbuserAuditLog.create(
+        user_id=user_id,
+        action=action.value,
+        message=message,
+        admin_id=admin_id if admin_id else None,
+        commit=commit,
+    )
diff --git a/app/abuser_utils.py b/app/abuser_utils.py
new file mode 100644
index 00000000..e0cdcb52
--- /dev/null
+++ b/app/abuser_utils.py
@@ -0,0 +1,294 @@
+import hmac
+import json
+import secrets
+from hashlib import sha256
+from typing import List, Dict, Optional
+
+from cryptography.exceptions import InvalidTag
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes as crypto_hashes
+from cryptography.hazmat.primitives.ciphers.aead import AESGCM
+from cryptography.hazmat.primitives.kdf.hkdf import HKDF
+
+from app import config, constants
+from app.abuser_audit_log_utils import emit_abuser_audit_log, AbuserAuditLogAction
+from app.db import Session
+from app.log import LOG
+from app.models import User, Alias, Mailbox, AbuserData, AbuserLookup
+
+
+def _derive_key_for_identifier(master_key: bytes, identifier_address: str) -> bytes:
+    if not identifier_address or not isinstance(identifier_address, str):
+        raise ValueError(
+            "Identifier address must be a non-empty string for key derivation."
+        )
+
+    normalized_identifier = identifier_address.lower()
+    hkdf_info = (constants.HKDF_INFO_TEMPLATE % normalized_identifier).encode("utf-8")
+    hkdf = HKDF(
+        algorithm=crypto_hashes.SHA256(),
+        length=32,
+        salt=config.ABUSER_HKDF_SALT,
+        info=hkdf_info,
+        backend=default_backend(),
+    )
+
+    return hkdf.derive(master_key)
+
+
+def check_if_abuser_email(new_address: str) -> Optional[AbuserLookup]:
+    """
+    Return the AbuserLookup entry if the given address (after hashing) is found in abuser_lookup.
+    """
+    mac_key_bytes = config.MAC_KEY
+    normalized_address = new_address.lower()
+    check_hmac = hmac.new(
+        mac_key_bytes, normalized_address.encode("utf-8"), sha256
+    ).hexdigest()
+
+    return (
+        AbuserLookup.filter(AbuserLookup.hashed_address == check_hmac).limit(1).first()
+    )
+
+
+def mark_user_as_abuser(
+    abuse_user: User, note: str, admin_id: Optional[int] = None
+) -> None:
+    abuse_user.disabled = True
+
+    emit_abuser_audit_log(
+        user_id=abuse_user.id,
+        action=AbuserAuditLogAction.MarkAbuser,
+        message=note,
+        admin_id=admin_id,
+    )
+    Session.commit()
+    _store_abuse_data(abuse_user)
+
+
+def _store_abuse_data(user: User) -> None:
+    """
+    Archive the given abusive user's data and update blocklist/lookup tables.
+    """
+    if not user.email:
+        raise ValueError(f"User ID {user.id} must have a primary email to be archived.")
+
+    try:
+        primary_email: str = user.email.lower()
+        aliases: List[Alias] = (
+            Alias.filter_by(user_id=user.id)
+            .enable_eagerloads(False)
+            .yield_per(500)
+            .all()
+        )
+        mailboxes: List[Mailbox] = Mailbox.filter_by(user_id=user.id).all()
+        bundle = {
+            "account_id": str(user.id),
+            "email": primary_email,
+            "user_created_at": user.created_at.isoformat() if user.created_at else None,
+            "aliases": [
+                {
+                    "address": alias.email.lower() if alias.email else None,
+                    "created_at": alias.created_at.isoformat()
+                    if alias.created_at
+                    else None,
+                }
+                for alias in aliases
+                if alias.email
+            ],
+            "mailboxes": [
+                {
+                    "address": mailbox.email.lower() if mailbox.email else None,
+                    "created_at": mailbox.created_at.isoformat()
+                    if mailbox.created_at
+                    else None,
+                }
+                for mailbox in mailboxes
+                if mailbox.email
+            ],
+        }
+        bundle_json_bytes = json.dumps(bundle, sort_keys=True).encode("utf-8")
+        k_bundle_random = AESGCM.generate_key(bit_length=256)
+        aesgcm_bundle_enc = AESGCM(k_bundle_random)
+        nonce_bundle = secrets.token_bytes(12)
+        encrypted_bundle_data = nonce_bundle + aesgcm_bundle_enc.encrypt(
+            nonce_bundle,
+            bundle_json_bytes,
+            f"{constants.AEAD_AAD_DATA}.bundle".encode("utf-8"),
+        )
+        abuser_data_entry = AbuserData(
+            user_id=user.id, encrypted_bundle=encrypted_bundle_data
+        )
+
+        Session.add(abuser_data_entry)
+        Session.flush()
+
+        blob_id = abuser_data_entry.id
+        all_identifiers_raw = (
+            [primary_email]
+            + [a.email for a in aliases if a.email]
+            + [m.email for m in mailboxes if m.email]
+        )
+        seen_normalized_identifiers = set()
+        mac_key_bytes = config.MAC_KEY
+        master_key_bytes = config.MASTER_ENC_KEY
+
+        for raw_identifier_address in all_identifiers_raw:
+            if not raw_identifier_address:
+                continue
+
+            normalized_identifier = raw_identifier_address.lower()
+
+            if normalized_identifier in seen_normalized_identifiers:
+                continue
+
+            seen_normalized_identifiers.add(normalized_identifier)
+            identifier_hmac = hmac.new(
+                mac_key_bytes, normalized_identifier.encode("utf-8"), sha256
+            ).hexdigest()
+
+            k_identifier_derived = _derive_key_for_identifier(
+                master_key_bytes, normalized_identifier
+            )
+            aesgcm_key_enc = AESGCM(k_identifier_derived)
+            nonce_key_encryption = secrets.token_bytes(12)
+            encrypted_k_bundle_for_this_identifier = (
+                nonce_key_encryption
+                + aesgcm_key_enc.encrypt(
+                    nonce_key_encryption,
+                    k_bundle_random,
+                    f"{constants.AEAD_AAD_DATA}.key".encode("utf-8"),
+                )
+            )
+            abuser_lookup_entry = AbuserLookup(
+                hashed_address=identifier_hmac,
+                abuser_data_id=blob_id,
+                bundle_k=encrypted_k_bundle_for_this_identifier,
+            )
+
+            Session.add(abuser_lookup_entry)
+
+        Session.commit()
+    except Exception:
+        Session.rollback()
+        LOG.exception("Error during archive_abusive_user")
+        raise
+
+
+def unmark_as_abusive_user(
+    user_id: int, note: str, admin_id: Optional[int] = None
+) -> None:
+    """
+    Fully remove abuser archive and lookup data for a given user_id.
+    This reverses the effects of archive_abusive_user().
+    """
+    LOG.i(f"Removing user {user_id} as an abuser.")
+    abuser_data_entry = AbuserData.filter_by(user_id=user_id).first()
+
+    if abuser_data_entry:
+        Session.delete(abuser_data_entry)
+
+    user = User.get(user_id)
+    user.disabled = False
+
+    emit_abuser_audit_log(
+        user_id=user.id,
+        admin_id=admin_id,
+        action=AbuserAuditLogAction.UnmarkAbuser,
+        message=note,
+    )
+    Session.commit()
+
+
+def get_abuser_bundles_for_address(target_address: str, admin_id: int) -> List[Dict]:
+    """
+    Given a target address (email, alias, or mailbox address),
+    return all decrypted bundle_json's that reference this address.
+    """
+    if not target_address:
+        return []
+
+    normalized_target_address = target_address.lower()
+    mac_key_bytes = config.MAC_KEY
+    master_key_bytes = config.MASTER_ENC_KEY
+
+    target_hmac = hmac.new(
+        mac_key_bytes, normalized_target_address.encode("utf-8"), sha256
+    ).hexdigest()
+    lookup_entries: List[AbuserLookup] = AbuserLookup.filter(
+        AbuserLookup.hashed_address == target_hmac
+    ).all()
+
+    if not lookup_entries:
+        return []
+
+    decrypted_bundles: List[Dict] = []
+
+    try:
+        k_target_address_derived = _derive_key_for_identifier(
+            master_key_bytes, normalized_target_address
+        )
+        aesgcm_key_dec = AESGCM(k_target_address_derived)
+    except ValueError as ve_derive:
+        LOG.e(
+            f"Error deriving key for target_address '{normalized_target_address}': {ve_derive}"
+        )
+        return []
+
+    for entry in lookup_entries:
+        blob_id = entry.abuser_data_id
+        encrypted_k_bundle_from_entry = entry.bundle_k
+        abuser_data_record: Optional[AbuserData] = AbuserData.filter_by(
+            id=blob_id
+        ).first()
+
+        if not abuser_data_record:
+            LOG.e(
+                f"Error: No AbuserData found for blob_id {blob_id} linked to target_address '{normalized_target_address}'. Skipping."
+            )
+            continue
+
+        encrypted_main_bundle_data = abuser_data_record.encrypted_bundle
+
+        try:
+            nonce_k_decryption = encrypted_k_bundle_from_entry[:12]
+            ciphertext_k_decryption = encrypted_k_bundle_from_entry[12:]
+            plaintext_k_bundle = aesgcm_key_dec.decrypt(
+                nonce_k_decryption,
+                ciphertext_k_decryption,
+                f"{constants.AEAD_AAD_DATA}.key".encode("utf-8"),
+            )
+            aesgcm_bundle_dec = AESGCM(plaintext_k_bundle)
+            nonce_main_bundle = encrypted_main_bundle_data[:12]
+            ciphertext_main_bundle = encrypted_main_bundle_data[12:]
+            decrypted_bundle_json_bytes = aesgcm_bundle_dec.decrypt(
+                nonce_main_bundle,
+                ciphertext_main_bundle,
+                f"{constants.AEAD_AAD_DATA}.bundle".encode("utf-8"),
+            )
+            bundle = json.loads(decrypted_bundle_json_bytes.decode("utf-8"))
+            decrypted_bundles.append(bundle)
+        except InvalidTag:
+            LOG.e(
+                f"Error: AEAD decryption failed for blob_id {blob_id} (either K_bundle or main bundle). InvalidTag. Target address: '{normalized_target_address}'."
+            )
+            continue
+        except ValueError as ve:
+            LOG.e(
+                f"Error: Decryption ValueError for blob_id {blob_id}: {ve}. Target address: '{normalized_target_address}'."
+            )
+            continue
+        except Exception:
+            LOG.e(
+                f"Error: General decryption exception for blob_id {blob_id}. Target address: '{normalized_target_address}'."
+ ) + continue + + emit_abuser_audit_log( + user_id=abuser_data_record.user_id, + admin_id=admin_id, + action=AbuserAuditLogAction.GetAbuserBundles, + message="The abuser bundle was requested.", + ) + + return decrypted_bundles diff --git a/app/admin_model.py b/app/admin_model.py index f54f783a..69a4d778 100644 --- a/app/admin_model.py +++ b/app/admin_model.py @@ -15,6 +15,7 @@ from flask_login import current_user from markupsafe import Markup from app import models, s3, config +from app.abuser_utils import mark_user_as_abuser, unmark_as_abusive_user from app.custom_domain_validation import ( CustomDomainValidation, DomainValidationResult, @@ -160,7 +161,9 @@ class UserAdmin(SLModelView): ) def action_disable_user(self, ids): for user in User.filter(User.id.in_(ids)): - user.disabled = True + mark_user_as_abuser( + user, f"An user {user.id} was marked as abuser.", current_user.id + ) flash(f"Disabled user {user.id}") AdminAuditLog.disable_user(current_user.id, user.id) @@ -174,7 +177,9 @@ class UserAdmin(SLModelView): ) def action_enable_user(self, ids): for user in User.filter(User.id.in_(ids)): - user.disabled = False + unmark_as_abusive_user( + user.id, f"An user {user.id} was unmarked as abuser.", current_user.id + ) flash(f"Enabled user {user.id}") AdminAuditLog.enable_user(current_user.id, user.id) diff --git a/app/api/views/auth.py b/app/api/views/auth.py index 3786de1d..69292968 100644 --- a/app/api/views/auth.py +++ b/app/api/views/auth.py @@ -9,6 +9,7 @@ from itsdangerous import Signer from app import email_utils from app.api.base import api_bp +from app.abuser_utils import check_if_abuser_email from app.config import FLASK_SECRET, DISABLE_REGISTRATION from app.dashboard.views.account_setting import send_reset_password_email from app.db import Session @@ -113,6 +114,12 @@ def auth_register(): ).send() return jsonify(error=f"cannot use {email} as personal inbox"), 400 + if check_if_abuser_email(email): + LOG.warn( + f"User with email {email} that was marked as abuser tried to register again" + ) + return jsonify(error=f"cannot use {email} as it was previously banned"), 400 + if not password or len(password) < 8: RegisterEvent(RegisterEvent.ActionType.failed, RegisterEvent.Source.api).send() return jsonify(error="password too short"), 400 diff --git a/app/auth/views/register.py b/app/auth/views/register.py index 1799a76f..874f0a26 100644 --- a/app/auth/views/register.py +++ b/app/auth/views/register.py @@ -5,9 +5,10 @@ from flask_wtf import FlaskForm from wtforms import StringField, validators from app import email_utils, config +from app.abuser_utils import check_if_abuser_email from app.auth.base import auth_bp -from app.config import CONNECT_WITH_PROTON, CONNECT_WITH_OIDC_ICON from app.auth.views.login_utils import get_referral +from app.config import CONNECT_WITH_PROTON, CONNECT_WITH_OIDC_ICON from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY from app.db import Session from app.email_utils import ( @@ -74,6 +75,8 @@ def register(): if not email_can_be_used_as_mailbox(email): flash("You cannot use this email address as your personal inbox.", "error") RegisterEvent(RegisterEvent.ActionType.email_in_use).send() + elif check_if_abuser_email(email): + flash("The email address provided is banned from registration.", "error") else: sanitized_email = sanitize_email(form.email.data) if personal_email_already_used(email) or personal_email_already_used( diff --git a/app/config.py b/app/config.py index 64ae82b6..e11756d5 100644 --- a/app/config.py +++ b/app/config.py @@ -686,3 
 ALIAS_TRASH_DAYS = int(os.environ.get("ALIAS_TRASH_DAYS", 30))
 ALLOWED_OAUTH_SCHEMES = get_env_csv("ALLOWED_OAUTH_SCHEMES", "auth.simplelogin,https")
 MAX_EMAIL_FORWARD_RECIPIENTS = int(os.environ.get("MAX_EMAIL_FORWARD_RECIPIENTS", 30))
+
+
+def read_hex_data(key: str, default: bytes) -> bytes:
+    data = os.environ.get(key)
+
+    return bytes.fromhex(data) if data else default
+
+
+MASTER_ENC_KEY = read_hex_data(
+    "MASTER_ENC_KEY_HEX", (FLASK_SECRET + "enckey").encode("utf-8")
+)
+MAC_KEY = read_hex_data("MAC_KEY_HEX", (FLASK_SECRET + "mackey").encode("utf-8"))
+ABUSER_HKDF_SALT = read_hex_data(
+    "ABUSER_HKDF_SALT", (FLASK_SECRET + "absalt").encode("utf-8")
+)
diff --git a/app/constants.py b/app/constants.py
index 48f44340..fa7fe0b5 100644
--- a/app/constants.py
+++ b/app/constants.py
@@ -2,6 +2,8 @@ import enum
 
 HEADER_ALLOW_API_COOKIES = "X-Sl-Allowcookies"
 DMARC_RECORD = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"
+HKDF_INFO_TEMPLATE = "enc_key.ab.sl.proton.me:%s"
+AEAD_AAD_DATA = "data.ab.sl.proton.me"
 
 
 class JobType(enum.Enum):
diff --git a/app/models.py b/app/models.py
index 295fb516..5aa1b8dd 100644
--- a/app/models.py
+++ b/app/models.py
@@ -357,6 +357,37 @@ class Fido(Base, ModelMixin):
     __table_args__ = (sa.Index("ix_fido_user_id", "user_id"),)
 
 
+class AbuserData(Base, ModelMixin):
+    __tablename__ = "abuser_data"
+
+    user_id = sa.Column(sa.Integer, nullable=False, index=True)
+    encrypted_bundle = sa.Column(sa.LargeBinary(), nullable=False)
+
+    __table_args__ = (sa.Index("ix_abuser_data_id", "id"),)
+
+
+class AbuserLookup(Base, ModelMixin):
+    __tablename__ = "abuser_lookup"
+
+    hashed_address = sa.Column(sa.String(64), nullable=False, index=True)
+    abuser_data_id = sa.Column(
+        sa.Integer,
+        sa.ForeignKey("abuser_data.id", ondelete="cascade"),
+        nullable=False,
+        index=True,
+    )
+    bundle_k = sa.Column(sa.LargeBinary(), nullable=False)
+
+
+class AbuserAuditLog(Base, ModelMixin):
+    __tablename__ = "abuser_audit_log"
+
+    user_id = sa.Column(sa.Integer, nullable=False, index=True)
+    admin_id = sa.Column(sa.Integer, nullable=True)
+    action = sa.Column(sa.String(255), nullable=False)
+    message = sa.Column(sa.Text, default=None, nullable=True)
+
+
 class User(Base, ModelMixin, UserMixin, PasswordOracle):
     __tablename__ = "users"
 
diff --git a/migrations/versions/2025_051618_e38002759d8f_.py b/migrations/versions/2025_051618_e38002759d8f_.py
new file mode 100644
index 00000000..8130f32c
--- /dev/null
+++ b/migrations/versions/2025_051618_e38002759d8f_.py
@@ -0,0 +1,68 @@
+"""empty message
+
+Revision ID: e38002759d8f
+Revises: 87da368d282b
+Create Date: 2025-05-16 18:27:19.683673
+
+"""
+import sqlalchemy_utils
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'e38002759d8f'
+down_revision = '87da368d282b'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('abuser_audit_log',
+    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+    sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
+    sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
+    sa.Column('user_id', sa.Integer(), nullable=False),
+    sa.Column('admin_id', sa.Integer(), nullable=True),
+    sa.Column('action', sa.String(length=255), nullable=False),
+    sa.Column('message', sa.Text(), nullable=True),
+    sa.PrimaryKeyConstraint('id')
+    )
+    op.create_index(op.f('ix_abuser_audit_log_user_id'), 'abuser_audit_log', ['user_id'], unique=False)
+    op.create_table('abuser_data',
+    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+    sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
+    sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
+    sa.Column('user_id', sa.Integer(), nullable=False),
+    sa.Column('encrypted_bundle', sa.LargeBinary(), nullable=False),
+    sa.PrimaryKeyConstraint('id')
+    )
+    op.create_index('ix_abuser_data_id', 'abuser_data', ['id'], unique=False)
+    op.create_index(op.f('ix_abuser_data_user_id'), 'abuser_data', ['user_id'], unique=False)
+    op.create_table('abuser_lookup',
+    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+    sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
+    sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
+    sa.Column('hashed_address', sa.String(length=64), nullable=False),
+    sa.Column('abuser_data_id', sa.Integer(), nullable=False),
+    sa.Column('bundle_k', sa.LargeBinary(), nullable=False),
+    sa.ForeignKeyConstraint(['abuser_data_id'], ['abuser_data.id'], ondelete='cascade'),
+    sa.PrimaryKeyConstraint('id')
+    )
+    op.create_index(op.f('ix_abuser_lookup_abuser_data_id'), 'abuser_lookup', ['abuser_data_id'], unique=False)
+    op.create_index(op.f('ix_abuser_lookup_hashed_address'), 'abuser_lookup', ['hashed_address'], unique=False)
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_index(op.f('ix_abuser_lookup_hashed_address'), table_name='abuser_lookup')
+    op.drop_index(op.f('ix_abuser_lookup_abuser_data_id'), table_name='abuser_lookup')
+    op.drop_table('abuser_lookup')
+    op.drop_index(op.f('ix_abuser_data_user_id'), table_name='abuser_data')
+    op.drop_index('ix_abuser_data_id', table_name='abuser_data')
+    op.drop_table('abuser_data')
+    op.drop_index(op.f('ix_abuser_audit_log_user_id'), table_name='abuser_audit_log')
+    op.drop_table('abuser_audit_log')
+    # ### end Alembic commands ###
diff --git a/tests/test.env b/tests/test.env
index a872df0c..6a56ef3f 100644
--- a/tests/test.env
+++ b/tests/test.env
@@ -76,4 +76,3 @@ ENABLE_ALL_REVERSE_ALIAS_REPLACEMENT=true
 MAX_NB_REVERSE_ALIAS_REPLACEMENT=200
 
 MEM_STORE_URI=redis://localhost
-
diff --git a/tests/test_abuser_audit_log_utils.py b/tests/test_abuser_audit_log_utils.py
new file mode 100644
index 00000000..529a4a1a
--- /dev/null
+++ b/tests/test_abuser_audit_log_utils.py
@@ -0,0 +1,27 @@
+from typing import List
+
+from app.abuser_audit_log_utils import emit_abuser_audit_log, AbuserAuditLogAction
+from app.models import AbuserAuditLog
+from app.utils import random_string
+from tests.utils import create_new_user
+
+
+def test_emit_abuser_audit_log_for_random_data():
+    user = create_new_user()
+
+    message = random_string()
+    action = AbuserAuditLogAction.MarkAbuser
+    emit_abuser_audit_log(
+        user_id=user.id,
+        action=action,
+        message=message,
+        commit=True,
+    )
+
+    logs_for_user: List[AbuserAuditLog] = AbuserAuditLog.filter_by(
+        user_id=user.id, action=action.value
+    ).all()
+    assert len(logs_for_user) == 1
+    assert logs_for_user[0].user_id == user.id
+    assert logs_for_user[0].action == action.value
+    assert logs_for_user[0].message == message
diff --git a/tests/test_abuser_utils.py b/tests/test_abuser_utils.py
new file mode 100644
index 00000000..1546ff2b
--- /dev/null
+++ b/tests/test_abuser_utils.py
@@ -0,0 +1,566 @@
+import hashlib
+import hmac
+import secrets
+
+import pytest
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes as crypto_hashes
+from cryptography.hazmat.primitives.kdf.hkdf import HKDF
+
+from app.abuser_utils import (
+    mark_user_as_abuser,
+    check_if_abuser_email,
+    get_abuser_bundles_for_address,
+    unmark_as_abusive_user,
+    _derive_key_for_identifier,
+)
+from app import constants
+from app.db import Session
+from app.models import AbuserLookup, AbuserData, Alias, Mailbox, User
+from tests.utils import random_email, create_new_user
+
+MOCK_MASTER_ENC_KEY = secrets.token_bytes(32)
+MOCK_MAC_KEY = secrets.token_bytes(32)
+MOCK_ABUSER_HKDF_SALT = secrets.token_bytes(32)
+
+
+@pytest.fixture(autouse=True)
+def mock_app_config(monkeypatch):
+    class MockConfig:
+        MASTER_ENC_KEY = MOCK_MASTER_ENC_KEY
+        MAC_KEY = MOCK_MAC_KEY
+        ABUSER_HKDF_SALT = MOCK_MASTER_ENC_KEY
+
+    monkeypatch.setattr("app.abuser_utils.config", MockConfig)
+
+
+def calculate_hmac(address: str) -> str:
+    normalized_address = address.lower()
+
+    return hmac.new(
+        MOCK_MAC_KEY, normalized_address.encode("utf-8"), hashlib.sha256
+    ).hexdigest()
+
+
+def helper_derive_kek_for_identifier_test(identifier_address: str) -> bytes:
+    normalized_identifier = identifier_address.lower()
+    hkdf_info_bytes = (constants.HKDF_INFO_TEMPLATE % normalized_identifier).encode(
+        "utf-8"
+    )
+    hkdf = HKDF(
+        algorithm=crypto_hashes.SHA256(),
+        length=32,
+        salt=None,
+        info=hkdf_info_bytes,
+        backend=default_backend(),
+    )
+
+    return hkdf.derive(MOCK_MASTER_ENC_KEY)
+
+
+def
get_lookup_count_for_address(address: str) -> int: + identifier_hmac = calculate_hmac(address) + + return ( + Session.query(AbuserLookup) + .filter(AbuserLookup.hashed_address == identifier_hmac) + .count() + ) + + +def test_blocked_address_is_denied(flask_client, monkeypatch): + owner_user_email = random_email() + owner_user = create_new_user(email=owner_user_email) + + if not owner_user.id: + Session.flush() + + bad_address_to_block = "troll@shtepan.com" + identifier_hmac = calculate_hmac(bad_address_to_block) + + a_data = AbuserData.create( + user_id=owner_user.id, + encrypted_bundle=secrets.token_bytes(64), + commit=True, + ) + AbuserLookup.create( + hashed_address=identifier_hmac, + abuser_data_id=a_data.id, + bundle_k=secrets.token_bytes(48), + commit=True, + ) + + assert check_if_abuser_email(bad_address_to_block) is not None + + Session.delete(a_data) + Session.delete(owner_user) + Session.commit() + + +def test_non_blocked_address_is_allowed(flask_client, monkeypatch): + safe_address = random_email() + identifier_hmac = calculate_hmac(safe_address) + + assert ( + AbuserLookup.filter(AbuserLookup.hashed_address == identifier_hmac).first() + is None + ) + assert check_if_abuser_email(safe_address) is None + + +def test_archive_basic_user(flask_client, monkeypatch): + user = create_new_user() + user_primary_email_normalized = user.email.lower() + + if not user.default_mailbox_id: + mb = Mailbox.create( + user_id=user.id, + email=random_email().lower(), + verified=True, + commit=True, + ) + user.default_mailbox_id = mb.id + Session.add(user) + Session.commit() + + alias1_email_normalized = random_email().lower() + Alias.create( + email=alias1_email_normalized, + user_id=user.id, + mailbox_id=user.default_mailbox_id, + commit=True, + ) + mailbox1 = Mailbox.get(user.default_mailbox_id) + + assert mailbox1 is not None + + mailbox1_email_normalized = mailbox1.email.lower() + Session.commit() + mark_user_as_abuser(user, "") + ab_data = AbuserData.filter_by(user_id=user.id).first() + + assert ab_data is not None + assert ab_data.encrypted_bundle is not None + + all_identifiers_to_check = { + user_primary_email_normalized, + alias1_email_normalized, + mailbox1_email_normalized, + } + + for identifier_str in all_identifiers_to_check: + if not identifier_str: + continue + + identifier_hmac = calculate_hmac(identifier_str) + lookup_entry = AbuserLookup.filter_by( + hashed_address=identifier_hmac, abuser_data_id=ab_data.id + ).first() + + assert lookup_entry is not None, f"Lookup entry missing for {identifier_str}" + assert lookup_entry.bundle_k is not None + + retrieved_bundles = get_abuser_bundles_for_address(identifier_str, -1) + + assert ( + len(retrieved_bundles) == 1 + ), f"Could not retrieve bundle for identifier: {identifier_str}" + + bundle = retrieved_bundles[0] + + assert bundle["account_id"] == str(user.id) + assert bundle["email"] == user_primary_email_normalized + assert any( + a["address"] == alias1_email_normalized for a in bundle.get("aliases", []) + ) + + +def test_archive_user_with_no_aliases_or_mailboxes(flask_client, monkeypatch): + user = create_new_user() + user_primary_email_normalized = user.email.lower() + Alias.filter_by(user_id=user.id).delete(synchronize_session=False) + Session.commit() + mark_user_as_abuser(user, "") + ab_data = AbuserData.filter_by(user_id=user.id).first() + + assert ab_data is not None + + retrieved_bundles = get_abuser_bundles_for_address( + user_primary_email_normalized, -1 + ) + + assert len(retrieved_bundles) == 1 + + bundle = 
retrieved_bundles[0] + + assert bundle["account_id"] == str(user.id) + assert bundle["email"] == user_primary_email_normalized + assert len(bundle.get("aliases", [])) == 0 + assert "mailboxes" in bundle + + +def test_duplicate_addresses_do_not_create_duplicate_lookups(flask_client, monkeypatch): + user = create_new_user() + duplicate_email_normalized = random_email().lower() + + if not user.default_mailbox_id: + mb = Mailbox.create( + user_id=user.id, email=random_email().lower(), verified=True, commit=True + ) + user.default_mailbox_id = mb.id + Session.add(user) + Session.commit() + + Alias.create( + email=duplicate_email_normalized, + user_id=user.id, + mailbox_id=user.default_mailbox_id, + commit=True, + ) + default_mb = Mailbox.get(user.default_mailbox_id) + + assert default_mb is not None + + default_mb.email = duplicate_email_normalized + Session.add(default_mb) + Session.commit() + mark_user_as_abuser(user, "") + identifier_hmac_duplicate = calculate_hmac(duplicate_email_normalized) + ab_data = AbuserData.filter_by(user_id=user.id).first() + + assert ab_data is not None + + matches_count_duplicate = AbuserLookup.filter_by( + hashed_address=identifier_hmac_duplicate, abuser_data_id=ab_data.id + ).count() + + assert matches_count_duplicate == 1 + + +def test_invalid_user_or_identifier_fails_gracefully(flask_client, monkeypatch): + with pytest.raises( + ValueError, match="Identifier address must be a non-empty string" + ): + _derive_key_for_identifier(MOCK_MASTER_ENC_KEY, None) + + with pytest.raises( + ValueError, match="Identifier address must be a non-empty string" + ): + _derive_key_for_identifier(MOCK_MASTER_ENC_KEY, "") + + user_obj_no_email = User(id=99999, email=None) + + with pytest.raises( + ValueError, match=f"User ID {user_obj_no_email.id} must have a primary email" + ): + mark_user_as_abuser(user_obj_no_email, "") + + +def test_can_decrypt_bundle_for_all_valid_identifiers(flask_client, monkeypatch): + user = create_new_user() + + if not user.default_mailbox_id: + mb = Mailbox.create( + user_id=user.id, + email=random_email().lower(), + verified=True, + commit=True, + ) + user.default_mailbox_id = mb.id + Session.add(user) + Session.commit() + + alias1_email_normalized = random_email().lower() + Alias.create( + email=alias1_email_normalized, + user_id=user.id, + mailbox_id=user.default_mailbox_id, + commit=True, + ) + mailbox1 = Mailbox.get(user.default_mailbox_id) + + assert mailbox1 is not None + + mailbox1_email_normalized = mailbox1.email.lower() + Session.commit() + mark_user_as_abuser(user, "") + ab_data = AbuserData.filter_by(user_id=user.id).first() + + assert ab_data is not None + + all_identifiers_to_check = { + user.email.lower(), + alias1_email_normalized, + mailbox1_email_normalized, + } + + for identifier_str in all_identifiers_to_check: + if not identifier_str: + continue + + retrieved_bundles = get_abuser_bundles_for_address(identifier_str, -1) + + assert len(retrieved_bundles) == 1, f"Failed for identifier: {identifier_str}" + + bundle = retrieved_bundles[0] + + assert bundle["account_id"] == str(user.id) + assert bundle["email"] == user.email.lower() + + +def test_db_rollback_on_error(monkeypatch, flask_client): + user = create_new_user() + original_commit = Session.commit + + def mock_commit_failure(): + raise RuntimeError("Simulated DB failure during commit") + + monkeypatch.setattr(Session, "commit", mock_commit_failure) + + with pytest.raises(RuntimeError, match="Simulated DB failure during commit"): + mark_user_as_abuser(user, "") + + 
monkeypatch.setattr(Session, "commit", original_commit) # Restore + Session.rollback() + + assert AbuserData.filter_by(user_id=user.id).first() is None + + identifier_hmac = calculate_hmac(user.email) + + assert AbuserLookup.filter_by(hashed_address=identifier_hmac).count() == 0 + + +def test_unarchive_abusive_user_removes_data(flask_client, monkeypatch): + user = create_new_user() + email_normalized = user.email.lower() + mark_user_as_abuser(user, "") + + assert AbuserData.filter_by(user_id=user.id).first() is not None + assert get_lookup_count_for_address(email_normalized) > 0 + + unmark_as_abusive_user(user.id, "") + + assert AbuserData.filter_by(user_id=user.id).first() is None + + assert get_lookup_count_for_address(email_normalized) == 0 + + +def test_unarchive_idempotent_on_missing_data(flask_client, monkeypatch): + user = create_new_user() + + assert AbuserData.filter_by(user_id=user.id).first() is None + + unmark_as_abusive_user(user.id, "") + + assert AbuserData.filter_by(user_id=user.id).first() is None + + +def test_abuser_data_deletion_cascades_to_lookup(flask_client, monkeypatch): + user = create_new_user() + mark_user_as_abuser(user, "") + ab_data = AbuserData.filter_by(user_id=user.id).first() + + assert ab_data is not None + + abuser_data_pk_id = ab_data.id + + assert AbuserLookup.filter_by(abuser_data_id=abuser_data_pk_id).count() > 0 + + Session.delete(ab_data) + Session.commit() + + assert AbuserLookup.filter_by(abuser_data_id=abuser_data_pk_id).count() == 0 + + +def test_archive_then_unarchive_then_rearchive_is_consistent(flask_client, monkeypatch): + user = create_new_user() + mark_user_as_abuser(user, "") + ab_data1 = AbuserData.filter_by(user_id=user.id).first() + + assert ab_data1 is not None + + unmark_as_abusive_user(user.id, "") + + assert AbuserData.filter_by(user_id=user.id).first() is None + + mark_user_as_abuser(user, "") + ab_data2 = AbuserData.filter_by(user_id=user.id).first() + + assert ab_data2 is not None + assert ab_data2.id != ab_data1.id + + +def test_get_abuser_bundles_returns_bundle_for_primary_email(flask_client, monkeypatch): + user = create_new_user() + email_normalized = user.email.lower() + mark_user_as_abuser(user, "") + bundles = get_abuser_bundles_for_address(email_normalized, -1) + + assert len(bundles) == 1 + + bundle = bundles[0] + + assert bundle["email"] == email_normalized + assert bundle["account_id"] == str(user.id) + assert "aliases" in bundle + assert "mailboxes" in bundle + + +def test_get_abuser_bundles_no_match_returns_empty(flask_client, monkeypatch): + bundles = get_abuser_bundles_for_address("bohdan@shtepan.com", -1) + assert bundles == [] + + +def test_get_abuser_bundles_from_alias_address(flask_client, monkeypatch): + user = create_new_user() + + if not user.default_mailbox_id: + mb = Mailbox.create( + user_id=user.id, email=random_email().lower(), verified=True, commit=True + ) + user.default_mailbox_id = mb.id + Session.add(user) + Session.commit() + + alias_email_normalized = random_email().lower() + Alias.create( + email=alias_email_normalized, + user_id=user.id, + mailbox_id=user.default_mailbox_id, + commit=True, + ) + mark_user_as_abuser(user, "") + results = get_abuser_bundles_for_address(alias_email_normalized, -1) + + assert len(results) == 1 + + bundle = results[0] + + assert bundle["email"] == user.email.lower() + assert any( + a["address"] == alias_email_normalized for a in bundle.get("aliases", []) + ) + + +def test_get_abuser_bundles_from_mailbox_address(flask_client, monkeypatch): + user = 
create_new_user() + mailbox = Mailbox.get(user.default_mailbox_id) if user.default_mailbox_id else None + + if not mailbox or not mailbox.email: + mailbox_email_for_test = random_email().lower() + + if mailbox: + mailbox.email = mailbox_email_for_test + else: + mailbox = Mailbox.create( + user_id=user.id, + email=mailbox_email_for_test, + verified=True, + commit=False, + ) + Session.flush() + user.default_mailbox_id = mailbox.id + + Session.add(mailbox) + + if not user.default_mailbox_id: + Session.add(user) + + Session.commit() + + current_mailbox_email_normalized = mailbox.email.lower() + mark_user_as_abuser(user, "") + + results = get_abuser_bundles_for_address(current_mailbox_email_normalized, -1) + + assert len(results) == 1 + + bundle = results[0] + + assert bundle["email"] == user.email.lower() + assert any( + m["address"] == current_mailbox_email_normalized + for m in bundle.get("mailboxes", []) + ) + + +def test_get_abuser_bundles_with_corrupt_encrypted_k_bundle_is_skipped( + flask_client, monkeypatch +): + user = create_new_user() + mark_user_as_abuser(user, "") + identifier_hmac = calculate_hmac(user.email) + lookup_entry = AbuserLookup.filter_by(hashed_address=identifier_hmac).first() + + assert lookup_entry is not None + + original_key_data = lookup_entry.bundle_k + corrupted_key_data = secrets.token_bytes(len(original_key_data)) + lookup_entry.bundle_k = corrupted_key_data + Session.add(lookup_entry) + Session.commit() + bundles = get_abuser_bundles_for_address(user.email, -1) + + assert bundles == [] + + lookup_entry.bundle_k = original_key_data + Session.add(lookup_entry) + Session.commit() + + +def test_get_abuser_bundles_with_corrupt_main_bundle_is_skipped( + flask_client, monkeypatch +): + user = create_new_user() + mark_user_as_abuser(user, "") + ab_data = AbuserData.filter_by(user_id=user.id).first() + + assert ab_data is not None + + original_main_bundle = ab_data.encrypted_bundle + corrupted_main_bundle = secrets.token_bytes(len(original_main_bundle)) + ab_data.encrypted_bundle = corrupted_main_bundle + + Session.add(ab_data) + Session.commit() + + bundles = get_abuser_bundles_for_address(user.email, -1) + + assert bundles == [] + + ab_data.encrypted_bundle = original_main_bundle + Session.add(ab_data) + Session.commit() + + +def test_archive_and_fetch_flow_end_to_end(flask_client, monkeypatch): + user = create_new_user() + mailbox = Mailbox.get(user.default_mailbox_id) if user.default_mailbox_id else None + + if not mailbox: + mailbox_email_for_test = random_email().lower() + mailbox = Mailbox.create( + user_id=user.id, + email=mailbox_email_for_test, + verified=True, + commit=True, + ) + user.default_mailbox_id = mailbox.id + Session.add(user) + Session.commit() + + current_mailbox_email_normalized = mailbox.email.lower() + mark_user_as_abuser(user, "") + bundles = get_abuser_bundles_for_address(user.email, -1) + + assert len(bundles) == 1 + + data = bundles[0] + + assert data is not None + assert data["email"] == user.email.lower() + assert data["account_id"] == str(user.id) + assert any( + mb["address"] == current_mailbox_email_normalized + for mb in data.get("mailboxes", []) + )
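
Note (not part of the patch): the archive/lookup scheme can be exercised in isolation. The sketch below mirrors what _store_abuse_data() and get_abuser_bundles_for_address() do — an HMAC-SHA256 blind index over the normalized address, a random 256-bit AES-GCM key for the bundle, and that key wrapped per identifier with an HKDF-derived key — but it runs outside the application, with throwaway byte strings standing in for MASTER_ENC_KEY, MAC_KEY and ABUSER_HKDF_SALT.

# Standalone sketch of the abuser_utils envelope scheme (illustrative only).
import hmac
import json
import secrets
from hashlib import sha256

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.hkdf import HKDF

HKDF_INFO_TEMPLATE = "enc_key.ab.sl.proton.me:%s"
AEAD_AAD_DATA = "data.ab.sl.proton.me"

master_key = secrets.token_bytes(32)  # stand-in for config.MASTER_ENC_KEY
mac_key = secrets.token_bytes(32)  # stand-in for config.MAC_KEY
hkdf_salt = secrets.token_bytes(32)  # stand-in for config.ABUSER_HKDF_SALT

identifier = "Troll@Example.com".lower()
bundle = {"account_id": "42", "email": identifier, "aliases": [], "mailboxes": []}


def derive_key(identifier_address: str) -> bytes:
    # Per-identifier key, as in _derive_key_for_identifier().
    return HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=hkdf_salt,
        info=(HKDF_INFO_TEMPLATE % identifier_address).encode("utf-8"),
        backend=default_backend(),
    ).derive(master_key)


# 1. Blind index: HMAC-SHA256 of the normalized address -> abuser_lookup.hashed_address.
hashed_address = hmac.new(mac_key, identifier.encode("utf-8"), sha256).hexdigest()

# 2. Encrypt the bundle once with a random key -> abuser_data.encrypted_bundle.
k_bundle = AESGCM.generate_key(bit_length=256)
nonce = secrets.token_bytes(12)
encrypted_bundle = nonce + AESGCM(k_bundle).encrypt(
    nonce,
    json.dumps(bundle, sort_keys=True).encode("utf-8"),
    f"{AEAD_AAD_DATA}.bundle".encode("utf-8"),
)

# 3. Wrap k_bundle under the identifier-derived key -> abuser_lookup.bundle_k.
nonce_k = secrets.token_bytes(12)
bundle_k = nonce_k + AESGCM(derive_key(identifier)).encrypt(
    nonce_k, k_bundle, f"{AEAD_AAD_DATA}.key".encode("utf-8")
)

# 4. Lookup path: re-derive the key from the queried address, unwrap, decrypt.
k_unwrapped = AESGCM(derive_key(identifier)).decrypt(
    bundle_k[:12], bundle_k[12:], f"{AEAD_AAD_DATA}.key".encode("utf-8")
)
plaintext = AESGCM(k_unwrapped).decrypt(
    encrypted_bundle[:12],
    encrypted_bundle[12:],
    f"{AEAD_AAD_DATA}.bundle".encode("utf-8"),
)
assert json.loads(plaintext) == bundle

Because the wrapping key is derived from the queried address itself, an archived bundle can only be opened by someone who already knows a concrete address tied to the account; the hashed_address column alone is not reversible.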
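
A second, smaller sketch (also illustrative, not from the patch): read_hex_data() in app/config.py decodes the new settings with bytes.fromhex() and falls back to FLASK_SECRET-derived bytes when they are unset, so a production deployment would export hex strings. The 32-byte length below is an assumption, matching what the tests generate.

# Generate hex-encoded values for the new abuser-archive settings.
# Variable names come from app/config.py; the 32-byte length is an assumption.
import secrets

for name in ("MASTER_ENC_KEY_HEX", "MAC_KEY_HEX", "ABUSER_HKDF_SALT"):
    print(f"{name}={secrets.token_bytes(32).hex()}")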