diff --git a/app_api/api/api/api_v1/schemas/base.py b/app_api/api/api/api_v1/schemas/base.py deleted file mode 100644 index e38cee3..0000000 --- a/app_api/api/api/api_v1/schemas/base.py +++ /dev/null @@ -1,6 +0,0 @@ -from yt_shared.schemas.base import RealBaseModel - - -class BaseOrmModel(RealBaseModel): - class Config(RealBaseModel.Config): - orm_mode = True diff --git a/app_api/api/api/api_v1/schemas/task.py b/app_api/api/api/api_v1/schemas/task.py index c618f3b..d1a95ac 100644 --- a/app_api/api/api/api_v1/schemas/task.py +++ b/app_api/api/api/api_v1/schemas/task.py @@ -3,9 +3,7 @@ from datetime import datetime from pydantic import StrictBool, StrictFloat, StrictInt, StrictStr from yt_shared.enums import DownMediaType, TaskSource, TaskStatus -from yt_shared.schemas.base import RealBaseModel - -from api.api.api_v1.schemas.base import BaseOrmModel +from yt_shared.schemas.base import BaseOrmModel, RealBaseModel class CacheSchema(BaseOrmModel): diff --git a/app_api/api/core/services/task.py b/app_api/api/core/services/task.py index c64cc78..40d4134 100644 --- a/app_api/api/core/services/task.py +++ b/app_api/api/core/services/task.py @@ -34,7 +34,7 @@ class TaskService: async def get_latest_task(self, include_meta: bool) -> Task: schema = self._get_schema(include_meta) task = await self._repository.get_latest_task(include_meta) - return schema.from_orm(task) + return schema.model_validate(task) async def get_task( self, @@ -43,7 +43,7 @@ class TaskService: ) -> Task: schema = self._get_schema(include_meta) task = await self._repository.get_task(id, include_meta) - return schema.from_orm(task) + return schema.model_validate(task) async def get_all_tasks( self, @@ -56,7 +56,7 @@ class TaskService: tasks = await self._repository.get_all_tasks( include_meta, status, limit, offset ) - return [schema.from_orm(task) for task in tasks] + return [schema.model_validate(task) for task in tasks] @staticmethod async def create_task_non_db( @@ -79,4 +79,4 @@ class TaskService: return CreateTaskOut(id=task_id, url=task.url, added_at=added_at, source=source) async def get_stats(self): - return TasksStatsSchema.from_orm(await self._repository.get_stats()) + return TasksStatsSchema.model_validate(await self._repository.get_stats()) diff --git a/app_bot/bot/core/config/schema.py b/app_bot/bot/core/config/schema.py index ca42bb8..4865a29 100644 --- a/app_bot/bot/core/config/schema.py +++ b/app_bot/bot/core/config/schema.py @@ -2,9 +2,10 @@ from pydantic import ( StrictBool, StrictInt, StrictStr, - constr, - validator, + StringConstraints, + field_validator, ) +from typing_extensions import Annotated from yt_shared.enums import DownMediaType from yt_shared.schemas.base import RealBaseModel @@ -48,10 +49,6 @@ class UserSchema(AnonymousUserSchema): return False -def _change_type(values: list[int]) -> list[AnonymousUserSchema]: - return [AnonymousUserSchema(id=id_) for id_ in values] - - class ApiSchema(RealBaseModel): upload_video_file: StrictBool upload_video_max_file_size: StrictInt @@ -59,14 +56,19 @@ class ApiSchema(RealBaseModel): silent: StrictBool video_caption: VideoCaptionSchema - _transform_chat_ids = validator('upload_to_chat_ids', pre=True)(_change_type) + @field_validator('upload_to_chat_ids', mode='before') + @classmethod + def transform_chat_ids(cls, values: list[int]) -> list[AnonymousUserSchema]: + return [AnonymousUserSchema(id=id_) for id_ in values] class TelegramSchema(RealBaseModel): api_id: StrictInt api_hash: StrictStr token: StrictStr - lang_code: constr(regex=_LANG_CODE_REGEX, to_lower=True) + lang_code: Annotated[ + str, StringConstraints(pattern=_LANG_CODE_REGEX, to_lower=True) + ] max_upload_tasks: StrictInt url_validation_regexes: list[str] allowed_users: list[UserSchema] diff --git a/yt_shared/requirements_shared.txt b/yt_shared/requirements_shared.txt index f89aaf4..bc056bc 100644 --- a/yt_shared/requirements_shared.txt +++ b/yt_shared/requirements_shared.txt @@ -3,5 +3,6 @@ SQLAlchemy==2.0.19 aio-pika==9.2.1 aiohttp==3.8.5 asyncpg==0.28.0 -pydantic==1.10.10 +pydantic==2.1.1 +pydantic_settings==2.0.2 uvloop==0.17.0 diff --git a/yt_shared/yt_shared/config.py b/yt_shared/yt_shared/config.py index 36f7e94..4027098 100644 --- a/yt_shared/yt_shared/config.py +++ b/yt_shared/yt_shared/config.py @@ -1,7 +1,8 @@ import logging from typing import KeysView -from pydantic import BaseSettings, validator +from pydantic import field_validator +from pydantic_settings import BaseSettings class Settings(BaseSettings): @@ -44,8 +45,8 @@ class Settings(BaseSettings): TMP_DOWNLOAD_DIR: str TMP_DOWNLOADED_DIR: str + @field_validator('LOG_LEVEL') @classmethod - @validator('LOG_LEVEL') def validate_log_level_value(cls, value: str) -> str: valid_values: KeysView[str] = logging._nameToLevel.keys() # noqa if value not in valid_values: diff --git a/yt_shared/yt_shared/rabbit/publisher.py b/yt_shared/yt_shared/rabbit/publisher.py index 79ef6bf..5977b65 100644 --- a/yt_shared/yt_shared/rabbit/publisher.py +++ b/yt_shared/yt_shared/rabbit/publisher.py @@ -29,7 +29,7 @@ class RmqPublisher(metaclass=Singleton): return isinstance(confirm, Basic.Ack) async def send_for_download(self, media_payload: InbMediaPayload) -> bool: - message = aio_pika.Message(body=media_payload.json().encode()) + message = aio_pika.Message(body=media_payload.model_dump_json().encode()) exchange = self._rabbit_mq.exchanges[INPUT_EXCHANGE] confirm = await exchange.publish( message, routing_key=INPUT_QUEUE, mandatory=True @@ -40,14 +40,14 @@ class RmqPublisher(metaclass=Singleton): self, error_payload: ErrorDownloadPayload | ErrorGeneralPayload ) -> bool: err_exchange = self._rabbit_mq.exchanges[ERROR_EXCHANGE] - err_message = aio_pika.Message(body=error_payload.json().encode()) + err_message = aio_pika.Message(body=error_payload.model_dump_json().encode()) confirm = await err_exchange.publish( err_message, routing_key=ERROR_QUEUE, mandatory=True ) return self._is_sent(confirm) async def send_download_finished(self, success_payload: SuccessPayload) -> bool: - message = aio_pika.Message(body=success_payload.json().encode()) + message = aio_pika.Message(body=success_payload.model_dump_json().encode()) exchange = self._rabbit_mq.exchanges[SUCCESS_EXCHANGE] confirm = await exchange.publish( message, routing_key=SUCCESS_QUEUE, mandatory=True diff --git a/yt_shared/yt_shared/schemas/base.py b/yt_shared/yt_shared/schemas/base.py index a14824b..54b55ab 100644 --- a/yt_shared/yt_shared/schemas/base.py +++ b/yt_shared/yt_shared/schemas/base.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Extra +from pydantic import BaseModel, ConfigDict from yt_shared.enums import RabbitPayloadType @@ -6,14 +6,11 @@ from yt_shared.enums import RabbitPayloadType class RealBaseModel(BaseModel): """Base Pydantic model. All models should inherit from this.""" - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra='forbid') - def json(self, *args, **kwargs) -> str: - """By default, dump without whitespaces.""" - if 'separators' not in kwargs: - kwargs['separators'] = (',', ':') - return super().json(*args, **kwargs) + +class BaseOrmModel(RealBaseModel): + model_config = ConfigDict(from_attributes=True, **RealBaseModel.model_config) class BaseRabbitPayloadModel(RealBaseModel): diff --git a/yt_shared/yt_shared/schemas/media.py b/yt_shared/yt_shared/schemas/media.py index 9e03b4b..315616e 100644 --- a/yt_shared/yt_shared/schemas/media.py +++ b/yt_shared/yt_shared/schemas/media.py @@ -8,7 +8,7 @@ from pydantic import ( StrictFloat, StrictInt, StrictStr, - root_validator, + model_validator, ) from yt_shared.enums import DownMediaType, MediaFileType, TaskSource, TelegramChatType @@ -19,7 +19,7 @@ from yt_shared.utils.common import format_bytes class InbMediaPayload(RealBaseModel): """RabbitMQ incoming media payload from Telegram Bot or API service.""" - id: uuid.UUID | None + id: uuid.UUID | None = None from_chat_id: StrictInt | None from_chat_type: TelegramChatType | None from_user_id: StrictInt | None @@ -68,7 +68,8 @@ class Video(BaseMedia): height: int | None = None thumb_path: StrictStr | None = None - @root_validator(pre=False) + @model_validator(mode='before') + @classmethod def _set_fields(cls, values: dict) -> dict: if not values['thumb_name']: values['thumb_name'] = f'{values["filename"]}-thumb.jpg' @@ -85,7 +86,8 @@ class DownMedia(RealBaseModel): root_path: StrictStr meta: dict - @root_validator(pre=True) + @model_validator(mode='before') + @classmethod def _validate(cls, values: dict) -> dict: if values['audio'] is None and values['video'] is None: raise ValueError('Provide audio, video or both.') diff --git a/yt_shared/yt_shared/schemas/ytdlp.py b/yt_shared/yt_shared/schemas/ytdlp.py index 951e8d7..5177971 100644 --- a/yt_shared/yt_shared/schemas/ytdlp.py +++ b/yt_shared/yt_shared/schemas/ytdlp.py @@ -1,8 +1,8 @@ import datetime -from pydantic import Field, StrictStr, validator +from pydantic import Field, StrictStr, field_validator -from yt_shared.schemas.base import RealBaseModel +from yt_shared.schemas.base import BaseOrmModel, RealBaseModel def _remove_microseconds(dt_obj: datetime.datetime) -> datetime.datetime: @@ -13,24 +13,21 @@ class LatestVersion(RealBaseModel): version: StrictStr retrieved_at: datetime.datetime + @field_validator('retrieved_at', mode='before') @classmethod - @validator('retrieved_at', pre=True) def remove_microseconds(cls, value: datetime.datetime) -> datetime.datetime: return _remove_microseconds(value) -class CurrentVersion(RealBaseModel): +class CurrentVersion(BaseOrmModel): version: StrictStr = Field(..., alias='current_version') updated_at: datetime.datetime + @field_validator('updated_at', mode='before') @classmethod - @validator('updated_at', pre=True) def remove_microseconds(cls, value: datetime.datetime) -> datetime.datetime: return _remove_microseconds(value) - class Config(RealBaseModel.Config): - orm_mode = True - class VersionContext(RealBaseModel): latest: LatestVersion diff --git a/yt_shared/yt_shared/ytdlp/version_checker.py b/yt_shared/yt_shared/ytdlp/version_checker.py index a8b0bcf..c48d84e 100644 --- a/yt_shared/yt_shared/ytdlp/version_checker.py +++ b/yt_shared/yt_shared/ytdlp/version_checker.py @@ -35,4 +35,4 @@ class YtdlpVersionChecker: async def get_current_version(self, db: AsyncSession) -> CurrentVersion: ytdlp_ = await self._ytdlp_repository.get_current_version(db) self._log.info('Current yt-dlp version: %s', ytdlp_.current_version) - return CurrentVersion.from_orm(ytdlp_) + return CurrentVersion.model_validate(ytdlp_)