Mirror of https://github.com/tropicoo/yt-dlp-bot.git (synced 2025-03-03 02:03:30 +08:00)

Migrate Pydantic 1 to Pydantic 2

This commit is contained in:
parent d9a34af46d
commit a3536073e6

11 changed files with 40 additions and 48 deletions

@@ -1,6 +0,0 @@
-from yt_shared.schemas.base import RealBaseModel
-
-
-class BaseOrmModel(RealBaseModel):
-    class Config(RealBaseModel.Config):
-        orm_mode = True

@@ -3,9 +3,7 @@ from datetime import datetime
 from pydantic import StrictBool, StrictFloat, StrictInt, StrictStr
 from yt_shared.enums import DownMediaType, TaskSource, TaskStatus
-from yt_shared.schemas.base import RealBaseModel
-
-from api.api.api_v1.schemas.base import BaseOrmModel
+from yt_shared.schemas.base import BaseOrmModel, RealBaseModel
 
 
 class CacheSchema(BaseOrmModel):
 

@@ -34,7 +34,7 @@ class TaskService:
     async def get_latest_task(self, include_meta: bool) -> Task:
         schema = self._get_schema(include_meta)
         task = await self._repository.get_latest_task(include_meta)
-        return schema.from_orm(task)
+        return schema.model_validate(task)
 
     async def get_task(
         self,

@@ -43,7 +43,7 @@ class TaskService:
     ) -> Task:
         schema = self._get_schema(include_meta)
         task = await self._repository.get_task(id, include_meta)
-        return schema.from_orm(task)
+        return schema.model_validate(task)
 
     async def get_all_tasks(
         self,

@@ -56,7 +56,7 @@ class TaskService:
         tasks = await self._repository.get_all_tasks(
             include_meta, status, limit, offset
         )
-        return [schema.from_orm(task) for task in tasks]
+        return [schema.model_validate(task) for task in tasks]
 
     @staticmethod
     async def create_task_non_db(

@@ -79,4 +79,4 @@ class TaskService:
         return CreateTaskOut(id=task_id, url=task.url, added_at=added_at, source=source)
 
     async def get_stats(self):
-        return TasksStatsSchema.from_orm(await self._repository.get_stats())
+        return TasksStatsSchema.model_validate(await self._repository.get_stats())

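Note: the from_orm → model_validate replacements above follow the standard Pydantic v2 pattern; validating from an ORM object now requires from_attributes=True in the model config. A minimal self-contained sketch (toy ItemSchema/ItemRow names, not the repo's models):

from pydantic import BaseModel, ConfigDict


class ItemSchema(BaseModel):
    # v2 spelling of v1's `orm_mode = True`
    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str


class ItemRow:
    """Stands in for a SQLAlchemy row/entity object."""

    def __init__(self, id_: int, name: str) -> None:
        self.id = id_
        self.name = name


# v1: ItemSchema.from_orm(row)  ->  v2: ItemSchema.model_validate(row)
print(ItemSchema.model_validate(ItemRow(1, 'demo')))  # id=1 name='demo'
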
@@ -2,9 +2,10 @@ from pydantic import (
     StrictBool,
     StrictInt,
     StrictStr,
-    constr,
-    validator,
+    StringConstraints,
+    field_validator,
 )
+from typing_extensions import Annotated
 from yt_shared.enums import DownMediaType
 from yt_shared.schemas.base import RealBaseModel
 

@@ -48,10 +49,6 @@ class UserSchema(AnonymousUserSchema):
         return False
 
 
-def _change_type(values: list[int]) -> list[AnonymousUserSchema]:
-    return [AnonymousUserSchema(id=id_) for id_ in values]
-
-
 class ApiSchema(RealBaseModel):
     upload_video_file: StrictBool
     upload_video_max_file_size: StrictInt

@@ -59,14 +56,19 @@ class ApiSchema(RealBaseModel):
     silent: StrictBool
     video_caption: VideoCaptionSchema
 
-    _transform_chat_ids = validator('upload_to_chat_ids', pre=True)(_change_type)
+    @field_validator('upload_to_chat_ids', mode='before')
+    @classmethod
+    def transform_chat_ids(cls, values: list[int]) -> list[AnonymousUserSchema]:
+        return [AnonymousUserSchema(id=id_) for id_ in values]
 
 
 class TelegramSchema(RealBaseModel):
     api_id: StrictInt
     api_hash: StrictStr
     token: StrictStr
-    lang_code: constr(regex=_LANG_CODE_REGEX, to_lower=True)
+    lang_code: Annotated[
+        str, StringConstraints(pattern=_LANG_CODE_REGEX, to_lower=True)
+    ]
     max_upload_tasks: StrictInt
     url_validation_regexes: list[str]
     allowed_users: list[UserSchema]

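Two v1 → v2 idioms appear in this hunk, sketched below with a made-up ChatConfig model (not the project's schema): constr(regex=..., to_lower=True) becomes Annotated[str, StringConstraints(pattern=..., to_lower=True)], and a reusable validator(..., pre=True) becomes a field_validator(..., mode='before') method.

from typing import Annotated

from pydantic import BaseModel, StringConstraints, field_validator


class ChatConfig(BaseModel):
    # Constrained string via Annotated metadata instead of constr()
    lang_code: Annotated[str, StringConstraints(pattern=r'^[a-zA-Z]{2}$', to_lower=True)]
    chat_ids: list[int]

    @field_validator('chat_ids', mode='before')
    @classmethod
    def _coerce_chat_ids(cls, values: list) -> list[int]:
        # Runs before field parsing, like a v1 `pre=True` validator.
        return [int(v) for v in values]


print(ChatConfig(lang_code='EN', chat_ids=['1', 2]))  # lang_code='en' chat_ids=[1, 2]
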
@@ -3,5 +3,6 @@ SQLAlchemy==2.0.19
 aio-pika==9.2.1
 aiohttp==3.8.5
 asyncpg==0.28.0
-pydantic==1.10.10
+pydantic==2.1.1
+pydantic_settings==2.0.2
 uvloop==0.17.0

@@ -1,7 +1,8 @@
 import logging
 from typing import KeysView
 
-from pydantic import BaseSettings, validator
+from pydantic import field_validator
+from pydantic_settings import BaseSettings
 
 
 class Settings(BaseSettings):

@@ -44,8 +45,8 @@ class Settings(BaseSettings):
     TMP_DOWNLOAD_DIR: str
     TMP_DOWNLOADED_DIR: str
 
-    @validator('LOG_LEVEL')
+    @field_validator('LOG_LEVEL')
     @classmethod
     def validate_log_level_value(cls, value: str) -> str:
         valid_values: KeysView[str] = logging._nameToLevel.keys()  # noqa
         if value not in valid_values:

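In v2, BaseSettings lives in the separate pydantic-settings package (hence the new requirements pin above), while field_validator stays in pydantic. A hedged, standalone sketch of the same shape (AppSettings and its default are illustrative, not from the repo):

import logging
import os

from pydantic import field_validator
from pydantic_settings import BaseSettings


class AppSettings(BaseSettings):
    LOG_LEVEL: str = 'INFO'

    @field_validator('LOG_LEVEL')
    @classmethod
    def _validate_log_level(cls, value: str) -> str:
        if value not in logging._nameToLevel:  # noqa
            raise ValueError(f'Invalid log level: {value}')
        return value


os.environ['LOG_LEVEL'] = 'DEBUG'  # BaseSettings reads env vars by field name
print(AppSettings().LOG_LEVEL)  # DEBUG
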
@@ -29,7 +29,7 @@ class RmqPublisher(metaclass=Singleton):
         return isinstance(confirm, Basic.Ack)
 
     async def send_for_download(self, media_payload: InbMediaPayload) -> bool:
-        message = aio_pika.Message(body=media_payload.json().encode())
+        message = aio_pika.Message(body=media_payload.model_dump_json().encode())
         exchange = self._rabbit_mq.exchanges[INPUT_EXCHANGE]
         confirm = await exchange.publish(
             message, routing_key=INPUT_QUEUE, mandatory=True

@@ -40,14 +40,14 @@ class RmqPublisher(metaclass=Singleton):
         self, error_payload: ErrorDownloadPayload | ErrorGeneralPayload
     ) -> bool:
         err_exchange = self._rabbit_mq.exchanges[ERROR_EXCHANGE]
-        err_message = aio_pika.Message(body=error_payload.json().encode())
+        err_message = aio_pika.Message(body=error_payload.model_dump_json().encode())
         confirm = await err_exchange.publish(
             err_message, routing_key=ERROR_QUEUE, mandatory=True
         )
         return self._is_sent(confirm)
 
     async def send_download_finished(self, success_payload: SuccessPayload) -> bool:
-        message = aio_pika.Message(body=success_payload.json().encode())
+        message = aio_pika.Message(body=success_payload.model_dump_json().encode())
         exchange = self._rabbit_mq.exchanges[SUCCESS_EXCHANGE]
         confirm = await exchange.publish(
             message, routing_key=SUCCESS_QUEUE, mandatory=True

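model_dump_json() is the v2 replacement for .json(); its output is already compact (no spaces after separators), which is why the custom json() override in RealBaseModel is dropped later in this commit. A toy payload model (not the bot's InbMediaPayload) to illustrate:

from pydantic import BaseModel


class Payload(BaseModel):
    task_id: str
    url: str


body = Payload(task_id='abc', url='https://example.com').model_dump_json().encode()
print(body)  # b'{"task_id":"abc","url":"https://example.com"}'
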
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Extra
+from pydantic import BaseModel, ConfigDict
 
 from yt_shared.enums import RabbitPayloadType
 

@@ -6,14 +6,11 @@ from yt_shared.enums import RabbitPayloadType
 class RealBaseModel(BaseModel):
     """Base Pydantic model. All models should inherit from this."""
 
-    class Config:
-        extra = Extra.forbid
+    model_config = ConfigDict(extra='forbid')
 
-    def json(self, *args, **kwargs) -> str:
-        """By default, dump without whitespaces."""
-        if 'separators' not in kwargs:
-            kwargs['separators'] = (',', ':')
-        return super().json(*args, **kwargs)
+
+class BaseOrmModel(RealBaseModel):
+    model_config = ConfigDict(from_attributes=True, **RealBaseModel.model_config)
 
 
 class BaseRabbitPayloadModel(RealBaseModel):

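The nested class Config becomes a model_config = ConfigDict(...) attribute in v2. A simplified standalone sketch of the inheritance pattern used above (StrictBase/OrmBase are made-up names); v2 also merges a parent's model_config into subclasses automatically, so the explicit unpacking mirrors the commit rather than being strictly required:

from pydantic import BaseModel, ConfigDict


class StrictBase(BaseModel):
    model_config = ConfigDict(extra='forbid')


class OrmBase(StrictBase):
    # Combine the parent's settings with from_attributes (v2's orm_mode)
    model_config = ConfigDict(from_attributes=True, **StrictBase.model_config)


print(OrmBase.model_config)  # contains both extra='forbid' and from_attributes=True
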
@@ -8,7 +8,7 @@ from pydantic import (
     StrictFloat,
     StrictInt,
     StrictStr,
-    root_validator,
+    model_validator,
 )
 
 from yt_shared.enums import DownMediaType, MediaFileType, TaskSource, TelegramChatType

@@ -19,7 +19,7 @@ from yt_shared.utils.common import format_bytes
 class InbMediaPayload(RealBaseModel):
     """RabbitMQ incoming media payload from Telegram Bot or API service."""
 
-    id: uuid.UUID | None
+    id: uuid.UUID | None = None
     from_chat_id: StrictInt | None
     from_chat_type: TelegramChatType | None
     from_user_id: StrictInt | None

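The added `= None` reflects a real v2 behavior change: an Optional/`| None` annotation no longer implies a default, so such fields become required unless a default is given. Quick toy illustration:

import uuid

from pydantic import BaseModel


class Example(BaseModel):
    id: uuid.UUID | None = None  # without `= None`, v2 would require this field


print(Example())  # id=None
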
@@ -68,7 +68,8 @@ class Video(BaseMedia):
     height: int | None = None
     thumb_path: StrictStr | None = None
 
-    @root_validator(pre=False)
+    @model_validator(mode='before')
+    @classmethod
     def _set_fields(cls, values: dict) -> dict:
         if not values['thumb_name']:
             values['thumb_name'] = f'{values["filename"]}-thumb.jpg'

@@ -85,7 +86,8 @@ class DownMedia(RealBaseModel):
     root_path: StrictStr
     meta: dict
 
-    @root_validator(pre=True)
+    @model_validator(mode='before')
+    @classmethod
     def _validate(cls, values: dict) -> dict:
         if values['audio'] is None and values['video'] is None:
             raise ValueError('Provide audio, video or both.')

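root_validator(pre=True) maps to model_validator(mode='before'), which receives the raw input (typically a dict) before field validation. A self-contained toy model (not the repo's DownMedia):

from pydantic import BaseModel, model_validator


class Media(BaseModel):
    audio: str | None = None
    video: str | None = None

    @model_validator(mode='before')
    @classmethod
    def _require_any(cls, values: dict) -> dict:
        if values.get('audio') is None and values.get('video') is None:
            raise ValueError('Provide audio, video or both.')
        return values


print(Media(video='clip.mp4'))  # audio=None video='clip.mp4'
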
@@ -1,8 +1,8 @@
 import datetime
 
-from pydantic import Field, StrictStr, validator
+from pydantic import Field, StrictStr, field_validator
 
-from yt_shared.schemas.base import RealBaseModel
+from yt_shared.schemas.base import BaseOrmModel, RealBaseModel
 
 
 def _remove_microseconds(dt_obj: datetime.datetime) -> datetime.datetime:

@@ -13,24 +13,21 @@ class LatestVersion(RealBaseModel):
     version: StrictStr
     retrieved_at: datetime.datetime
 
-    @validator('retrieved_at', pre=True)
+    @field_validator('retrieved_at', mode='before')
     @classmethod
     def remove_microseconds(cls, value: datetime.datetime) -> datetime.datetime:
         return _remove_microseconds(value)
 
 
-class CurrentVersion(RealBaseModel):
+class CurrentVersion(BaseOrmModel):
     version: StrictStr = Field(..., alias='current_version')
     updated_at: datetime.datetime
 
-    @validator('updated_at', pre=True)
+    @field_validator('updated_at', mode='before')
     @classmethod
     def remove_microseconds(cls, value: datetime.datetime) -> datetime.datetime:
         return _remove_microseconds(value)
 
-    class Config(RealBaseModel.Config):
-        orm_mode = True
-
 
 class VersionContext(RealBaseModel):
     latest: LatestVersion

@@ -35,4 +35,4 @@ class YtdlpVersionChecker:
     async def get_current_version(self, db: AsyncSession) -> CurrentVersion:
         ytdlp_ = await self._ytdlp_repository.get_current_version(db)
         self._log.info('Current yt-dlp version: %s', ytdlp_.current_version)
-        return CurrentVersion.from_orm(ytdlp_)
+        return CurrentVersion.model_validate(ytdlp_)