db operation

This commit is contained in:
Benny 2024-11-30 15:27:11 +01:00
parent 991a03630b
commit 580c49955e
No known key found for this signature in database
GPG key ID: 6CD0DBDA5235D481
3 changed files with 59 additions and 16 deletions

View file

@ -7,7 +7,6 @@ RUN pdm install
FROM python:3.12-alpine AS runner
WORKDIR /app
ENV TZ=Europe/Stockholm
RUN apk update && apk add --no-cache ffmpeg aria2
COPY --from=pybuilder /build/.venv/lib/ /usr/local/lib/

View file

@ -8,6 +8,7 @@ from typing import Literal
from sqlalchemy import Column, Enum, Float, ForeignKey, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.dialects.mysql import JSON
# ytdlbot - model.py
@ -19,11 +20,10 @@ class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(String(255), unique=True, nullable=False) # telegram user id
username = Column(String(255), nullable=True)
name = Column(String(255), nullable=True)
free = Column(Integer, default=0)
user_id = Column(Integer, unique=True, nullable=False) # telegram user id
free = Column(Integer, default=5)
paid = Column(Integer, default=0)
config = Column(JSON)
settings = relationship("Setting", back_populates="user", cascade="all, delete-orphan")
payments = relationship("Payment", back_populates="user", cascade="all, delete-orphan")
@ -33,8 +33,8 @@ class Setting(Base):
__tablename__ = "settings"
id = Column(Integer, primary_key=True, autoincrement=True)
download = Column(Enum("high", "medium", "low", "audio", "custom"), nullable=False)
upload = Column(Enum("video", "audio", "document"), nullable=False)
download = Column(Enum("high", "medium", "low", "audio", "custom"), nullable=False, default="high")
upload = Column(Enum("video", "audio", "document"), nullable=False, default="video")
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
user = relationship("User", back_populates="settings")
@ -82,36 +82,80 @@ def session_manager():
def get_download_settings(uid) -> Literal["high", "medium", "low", "audio", "custom"]:
with session_manager() as session:
return "custom"
data = session.query(Setting).filter(Setting.user_id == uid).first()
if data:
return data.download
return "high"
def get_upload_settings(uid) -> Literal["video", "audio", "document"]:
with session_manager() as session:
data = session.query(Setting).filter(Setting.user_id == uid).first()
if data:
return data.upload
return "video"
def set_user_settings(uid: int, key: str, value: str):
# set download or upload settings
pass
with session_manager() as session:
# upsert
setting = session.query(Setting).filter(Setting.user_id == uid).first()
if setting:
setattr(setting, key, value)
else:
session.add(Setting(user_id=uid, **{key: value}))
def get_free_quota(uid: int):
pass
with session_manager() as session:
data = session.query(User).filter(User.user_id == uid).first()
if data:
return data.free
return 5
def get_paid_quota(uid: int):
if not os.getenv("ENABLE_VIP"):
return math.inf
if os.getenv("ENABLE_VIP"):
with session_manager() as session:
data = session.query(User).filter(User.user_id == uid).first()
if data:
return data.paid
return 0
return math.inf
def reset_free_quota(uid: int):
pass
with session_manager() as session:
data = session.query(User).filter(User.user_id == uid).first()
if data:
data.free = 5
def add_paid_quota(uid: int, amount: int):
pass
with session_manager() as session:
data = session.query(User).filter(User.user_id == uid).first()
if data:
data.paid += amount
def use_quota(uid: int):
# use free first, then paid
pass
with session_manager() as session:
user = session.query(User).filter(User.user_id == uid).first()
if user:
if user.free > 0:
user.free -= 1
elif user.paid > 0:
user.paid -= 1
else:
raise Exception("Quota exhausted")
def init_user(uid: int):
with session_manager() as session:
user = session.query(User).filter(User.user_id == uid).first()
if not user:
session.add(User(user_id=uid))

View file

@ -61,7 +61,7 @@ class BaseDownloader(ABC):
self._user_id = user_id
self._id = _id
self._tempdir = tempfile.TemporaryDirectory(prefix="ytdl-")
self._bot_msg = self._client.get_messages(self._user_id, self._id)
self._bot_msg: Types.Message = self._client.get_messages(self._user_id, self._id)
def __del__(self):
self._tempdir.cleanup()