Migrate Pydantic 1 to Pydantic 2

This commit is contained in:
Taras Terletskyi 2023-08-08 20:53:56 +03:00
parent d9a34af46d
commit a3536073e6
11 changed files with 40 additions and 48 deletions

View file

@ -1,6 +0,0 @@
from yt_shared.schemas.base import RealBaseModel
class BaseOrmModel(RealBaseModel):
class Config(RealBaseModel.Config):
orm_mode = True

View file

@ -3,9 +3,7 @@ from datetime import datetime
from pydantic import StrictBool, StrictFloat, StrictInt, StrictStr
from yt_shared.enums import DownMediaType, TaskSource, TaskStatus
from yt_shared.schemas.base import RealBaseModel
from api.api.api_v1.schemas.base import BaseOrmModel
from yt_shared.schemas.base import BaseOrmModel, RealBaseModel
class CacheSchema(BaseOrmModel):

View file

@ -34,7 +34,7 @@ class TaskService:
async def get_latest_task(self, include_meta: bool) -> Task:
schema = self._get_schema(include_meta)
task = await self._repository.get_latest_task(include_meta)
return schema.from_orm(task)
return schema.model_validate(task)
async def get_task(
self,
@ -43,7 +43,7 @@ class TaskService:
) -> Task:
schema = self._get_schema(include_meta)
task = await self._repository.get_task(id, include_meta)
return schema.from_orm(task)
return schema.model_validate(task)
async def get_all_tasks(
self,
@ -56,7 +56,7 @@ class TaskService:
tasks = await self._repository.get_all_tasks(
include_meta, status, limit, offset
)
return [schema.from_orm(task) for task in tasks]
return [schema.model_validate(task) for task in tasks]
@staticmethod
async def create_task_non_db(
@ -79,4 +79,4 @@ class TaskService:
return CreateTaskOut(id=task_id, url=task.url, added_at=added_at, source=source)
async def get_stats(self):
return TasksStatsSchema.from_orm(await self._repository.get_stats())
return TasksStatsSchema.model_validate(await self._repository.get_stats())

View file

@ -2,9 +2,10 @@ from pydantic import (
StrictBool,
StrictInt,
StrictStr,
constr,
validator,
StringConstraints,
field_validator,
)
from typing_extensions import Annotated
from yt_shared.enums import DownMediaType
from yt_shared.schemas.base import RealBaseModel
@ -48,10 +49,6 @@ class UserSchema(AnonymousUserSchema):
return False
def _change_type(values: list[int]) -> list[AnonymousUserSchema]:
return [AnonymousUserSchema(id=id_) for id_ in values]
class ApiSchema(RealBaseModel):
upload_video_file: StrictBool
upload_video_max_file_size: StrictInt
@ -59,14 +56,19 @@ class ApiSchema(RealBaseModel):
silent: StrictBool
video_caption: VideoCaptionSchema
_transform_chat_ids = validator('upload_to_chat_ids', pre=True)(_change_type)
@field_validator('upload_to_chat_ids', mode='before')
@classmethod
def transform_chat_ids(cls, values: list[int]) -> list[AnonymousUserSchema]:
return [AnonymousUserSchema(id=id_) for id_ in values]
class TelegramSchema(RealBaseModel):
api_id: StrictInt
api_hash: StrictStr
token: StrictStr
lang_code: constr(regex=_LANG_CODE_REGEX, to_lower=True)
lang_code: Annotated[
str, StringConstraints(pattern=_LANG_CODE_REGEX, to_lower=True)
]
max_upload_tasks: StrictInt
url_validation_regexes: list[str]
allowed_users: list[UserSchema]

View file

@ -3,5 +3,6 @@ SQLAlchemy==2.0.19
aio-pika==9.2.1
aiohttp==3.8.5
asyncpg==0.28.0
pydantic==1.10.10
pydantic==2.1.1
pydantic_settings==2.0.2
uvloop==0.17.0

View file

@ -1,7 +1,8 @@
import logging
from typing import KeysView
from pydantic import BaseSettings, validator
from pydantic import field_validator
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
@ -44,8 +45,8 @@ class Settings(BaseSettings):
TMP_DOWNLOAD_DIR: str
TMP_DOWNLOADED_DIR: str
@field_validator('LOG_LEVEL')
@classmethod
@validator('LOG_LEVEL')
def validate_log_level_value(cls, value: str) -> str:
valid_values: KeysView[str] = logging._nameToLevel.keys() # noqa
if value not in valid_values:

View file

@ -29,7 +29,7 @@ class RmqPublisher(metaclass=Singleton):
return isinstance(confirm, Basic.Ack)
async def send_for_download(self, media_payload: InbMediaPayload) -> bool:
message = aio_pika.Message(body=media_payload.json().encode())
message = aio_pika.Message(body=media_payload.model_dump_json().encode())
exchange = self._rabbit_mq.exchanges[INPUT_EXCHANGE]
confirm = await exchange.publish(
message, routing_key=INPUT_QUEUE, mandatory=True
@ -40,14 +40,14 @@ class RmqPublisher(metaclass=Singleton):
self, error_payload: ErrorDownloadPayload | ErrorGeneralPayload
) -> bool:
err_exchange = self._rabbit_mq.exchanges[ERROR_EXCHANGE]
err_message = aio_pika.Message(body=error_payload.json().encode())
err_message = aio_pika.Message(body=error_payload.model_dump_json().encode())
confirm = await err_exchange.publish(
err_message, routing_key=ERROR_QUEUE, mandatory=True
)
return self._is_sent(confirm)
async def send_download_finished(self, success_payload: SuccessPayload) -> bool:
message = aio_pika.Message(body=success_payload.json().encode())
message = aio_pika.Message(body=success_payload.model_dump_json().encode())
exchange = self._rabbit_mq.exchanges[SUCCESS_EXCHANGE]
confirm = await exchange.publish(
message, routing_key=SUCCESS_QUEUE, mandatory=True

View file

@ -1,4 +1,4 @@
from pydantic import BaseModel, Extra
from pydantic import BaseModel, ConfigDict
from yt_shared.enums import RabbitPayloadType
@ -6,14 +6,11 @@ from yt_shared.enums import RabbitPayloadType
class RealBaseModel(BaseModel):
"""Base Pydantic model. All models should inherit from this."""
class Config:
extra = Extra.forbid
model_config = ConfigDict(extra='forbid')
def json(self, *args, **kwargs) -> str:
"""By default, dump without whitespaces."""
if 'separators' not in kwargs:
kwargs['separators'] = (',', ':')
return super().json(*args, **kwargs)
class BaseOrmModel(RealBaseModel):
model_config = ConfigDict(from_attributes=True, **RealBaseModel.model_config)
class BaseRabbitPayloadModel(RealBaseModel):

View file

@ -8,7 +8,7 @@ from pydantic import (
StrictFloat,
StrictInt,
StrictStr,
root_validator,
model_validator,
)
from yt_shared.enums import DownMediaType, MediaFileType, TaskSource, TelegramChatType
@ -19,7 +19,7 @@ from yt_shared.utils.common import format_bytes
class InbMediaPayload(RealBaseModel):
"""RabbitMQ incoming media payload from Telegram Bot or API service."""
id: uuid.UUID | None
id: uuid.UUID | None = None
from_chat_id: StrictInt | None
from_chat_type: TelegramChatType | None
from_user_id: StrictInt | None
@ -68,7 +68,8 @@ class Video(BaseMedia):
height: int | None = None
thumb_path: StrictStr | None = None
@root_validator(pre=False)
@model_validator(mode='before')
@classmethod
def _set_fields(cls, values: dict) -> dict:
if not values['thumb_name']:
values['thumb_name'] = f'{values["filename"]}-thumb.jpg'
@ -85,7 +86,8 @@ class DownMedia(RealBaseModel):
root_path: StrictStr
meta: dict
@root_validator(pre=True)
@model_validator(mode='before')
@classmethod
def _validate(cls, values: dict) -> dict:
if values['audio'] is None and values['video'] is None:
raise ValueError('Provide audio, video or both.')

View file

@ -1,8 +1,8 @@
import datetime
from pydantic import Field, StrictStr, validator
from pydantic import Field, StrictStr, field_validator
from yt_shared.schemas.base import RealBaseModel
from yt_shared.schemas.base import BaseOrmModel, RealBaseModel
def _remove_microseconds(dt_obj: datetime.datetime) -> datetime.datetime:
@ -13,24 +13,21 @@ class LatestVersion(RealBaseModel):
version: StrictStr
retrieved_at: datetime.datetime
@field_validator('retrieved_at', mode='before')
@classmethod
@validator('retrieved_at', pre=True)
def remove_microseconds(cls, value: datetime.datetime) -> datetime.datetime:
return _remove_microseconds(value)
class CurrentVersion(RealBaseModel):
class CurrentVersion(BaseOrmModel):
version: StrictStr = Field(..., alias='current_version')
updated_at: datetime.datetime
@field_validator('updated_at', mode='before')
@classmethod
@validator('updated_at', pre=True)
def remove_microseconds(cls, value: datetime.datetime) -> datetime.datetime:
return _remove_microseconds(value)
class Config(RealBaseModel.Config):
orm_mode = True
class VersionContext(RealBaseModel):
latest: LatestVersion

View file

@ -35,4 +35,4 @@ class YtdlpVersionChecker:
async def get_current_version(self, db: AsyncSession) -> CurrentVersion:
ytdlp_ = await self._ytdlp_repository.get_current_version(db)
self._log.info('Current yt-dlp version: %s', ytdlp_.current_version)
return CurrentVersion.from_orm(ytdlp_)
return CurrentVersion.model_validate(ytdlp_)