new update

This commit is contained in:
Taras Terletskyi 2022-11-04 18:45:33 +02:00
parent a2e9684044
commit 20d8d55a66
25 changed files with 136 additions and 91 deletions

View file

@ -10,4 +10,7 @@ pyproject.toml
api/Dockerfile
bot/Dockerfile
bot/config-example.yml
bot/*.session
bot/*.session-journal
worker/Dockerfile

View file

@ -6,12 +6,12 @@ Simple and reliable YouTube Download Telegram Bot.
## 😂 Features
* Download videos from any [yt-dlp](https://github.com/yt-dlp/yt-dlp) supported website
* Upload downloaded videos to the Telegram chat
* Additionally, trigger video download by sending link to API call
* Track download tasks in the database or API
* Trigger video download by sending link to an API
* Track download tasks via API
## ⚙ Quick Setup
1. Create Telegram bot using [BotFather](https://t.me/BotFather) and get your `token`
2. [Get own Telegram API key](https://my.telegram.org/apps) (`api_id` and `api_hash`)
2. [Get your Telegram API key](https://my.telegram.org/apps) (`api_id` and `api_hash`)
3. [Find your Telegram User ID](https://stackoverflow.com/questions/32683992/find-out-my-own-user-id-for-sending-a-message-with-telegram-api)
4. Copy `bot/config-example.yml` to `bot/config.yml`
5. Write `token`, `api_id`, `api_hash` to `bot/config.yml` by changing respective placeholders
@ -37,6 +37,9 @@ docker compose build base-image
# Build and run all services in detached mode
docker compose up --build -d -t 0 && docker compose logs --tail 100 -f
# Stop all services
docker compose stop -t 0
```
Your telegram bot should send you a startup message:
@ -58,7 +61,6 @@ If your URL can't be downloaded for some reason, you will see this
- **API**: default port `1984` and no auth. Port can be changed in `docker-compose.yml`
- **RabbitMQ**: default creds are located in `envs/.env_common`
- **PostgreSQL**: default creds are located in `envs/.env_common`. Same creds are stored for Alembic in `alembic.ini`.
- **PGAdmin**: default creds are located in `docker-compose.yml`
## API
By default, API service will run on your `localhost` and `1984` port.

View file

@ -3,7 +3,7 @@ import logging
from core.config import settings
def setup_logging():
def setup_logging() -> None:
log_format = '%(asctime)s - [%(levelname)s] - [%(name)s:%(lineno)s] - %(message)s'
logging.basicConfig(
format=log_format, level=logging.getLevelName(settings.LOG_LEVEL)

View file

@ -3,6 +3,7 @@ telegram:
api_hash: "<PASTE-HERE-TELEGRAM-API-HASH>"
token: "<PASTE-HERE-TELEGRAM-TOKEN>"
lang_code: !!str "en"
max_upload_tasks: 5
allowed_users:
- id: 00000000000
upload:
@ -11,6 +12,10 @@ telegram:
forward_to_group: !!bool False
forward_group_id: -00000000000
silent: !!bool False
video_caption:
include_title: !!bool True
include_filename: !!bool False
include_link: !!bool True
api:
upload_vide_file: !!bool False
upload_video_max_file_size: 2147483648
@ -18,4 +23,8 @@ telegram:
- 00000000000
- -00000000000
silent: !!bool False
ytdlp_version_check_interval: 3600
video_caption:
include_title: !!bool True
include_filename: !!bool False
include_link: !!bool True
ytdlp_version_check_interval: 86400

View file

@ -26,9 +26,10 @@ class VideoBot(Client):
user.id: user for user in self.conf.telegram.allowed_users
}
@staticmethod
async def run_forever() -> None:
async def run_forever(self) -> None:
"""Firstly 'await bot.start()' should be called."""
if not self.is_initialized:
raise RuntimeError('Bot was not started (initialized).')
while True:
await asyncio.sleep(86400)

View file

@ -31,7 +31,7 @@ class BotLauncher:
self._setup_handlers()
await self._start_bot()
def _setup_handlers(self):
def _setup_handlers(self) -> None:
cb = TelegramCallback()
self._bot.add_handler(
MessageHandler(
@ -59,8 +59,8 @@ class BotLauncher:
exception_message_args=(task_name,),
)
async def _start_tasks(self):
await self._rabbit_worker_manager.start_worker_tasks()
async def _start_tasks(self) -> None:
await self._rabbit_worker_manager.start_workers()
task_name = YtdlpNewVersionNotifyTask.__class__.__name__
create_task(

View file

@ -6,13 +6,13 @@ from pyrogram.types import Message
from core.bot import VideoBot
from core.service import URLService
from core.utils import bold
from yt_shared.enums import TelegramChatType
from yt_shared.emoji import SUCCESS_EMOJI
from yt_shared.enums import TelegramChatType
from yt_shared.schemas.url import URL
class TelegramCallback:
_MSG_SEND_OK = f'{SUCCESS_EMOJI} {bold("URL sent for download")}'
_MSG_SEND_OK = f'{SUCCESS_EMOJI} {bold("{count}URL{plural} sent for download")}'
_MSG_SEND_FAIL = f'🛑 {bold("Failed to send URL for download")}'
def __init__(self) -> None:
@ -32,8 +32,18 @@ class TelegramCallback:
self._log.debug(message)
urls = self._get_urls(message)
await self._url_service.process_urls(urls=urls)
await self._send_acknowledge_message(message=message, urls=urls)
async def _send_acknowledge_message(
self, message: Message, urls: list[URL]
) -> None:
urls_count = len(urls)
is_multiple = urls_count > 1
await message.reply(
self._MSG_SEND_OK,
self._MSG_SEND_OK.format(
count=f'{urls_count} ' if is_multiple else '',
plural='s' if is_multiple else '',
),
parse_mode=ParseMode.HTML,
reply_to_message_id=message.id,
)

View file

@ -16,23 +16,30 @@ class BaseUserSchema(RealBaseModel):
id: StrictInt
@property
def is_base_user(self):
def is_base_user(self) -> bool:
return True
class VideoCaptionSchema(RealBaseModel):
include_title: StrictBool
include_filename: StrictBool
include_link: StrictBool
class UserUploadSchema(RealBaseModel):
upload_vide_file: StrictBool
upload_video_max_file_size: StrictInt
forward_to_group: StrictBool
forward_group_id: StrictInt | None
silent: StrictBool
video_caption: VideoCaptionSchema
class UserSchema(BaseUserSchema):
upload: UserUploadSchema
@property
def is_base_user(self):
def is_base_user(self) -> bool:
return False
@ -45,6 +52,7 @@ class ApiSchema(RealBaseModel):
upload_video_max_file_size: StrictInt
upload_to_chat_ids: list[BaseUserSchema]
silent: StrictBool
video_caption: VideoCaptionSchema
_transform_chat_ids = validator('upload_to_chat_ids', pre=True)(change_type)
@ -54,6 +62,7 @@ class TelegramSchema(RealBaseModel):
api_hash: StrictStr
token: StrictStr
lang_code: constr(regex=_LANG_CODE_REGEX, to_lower=True)
max_upload_tasks: StrictInt
allowed_users: list[UserSchema]
api: ApiSchema

View file

@ -39,12 +39,14 @@ class SuccessHandler(AbstractHandler):
async def _create_upload_task(self) -> None:
"""Upload video to Telegram chat."""
semaphore = asyncio.Semaphore(value=self._bot.conf.telegram.max_upload_tasks)
task_name = UploadTask.__class__.__name__
await create_task(
UploadTask(
body=self._body,
users=self._receiving_users,
bot=self._bot,
semaphore=semaphore,
).run(),
task_name=task_name,
logger=self._log,

View file

@ -41,6 +41,7 @@ class UploadTask(AbstractTask):
body: SuccessPayload,
users: list[BaseUserSchema | UserSchema],
bot: 'VideoBot',
semaphore: asyncio.Semaphore,
) -> None:
super().__init__()
self._config = get_main_config()
@ -50,15 +51,22 @@ class UploadTask(AbstractTask):
self.thumb_path = os.path.join(settings.TMP_DOWNLOAD_PATH, body.thumb_name)
self._bot = bot
self._users = users
self._semaphore = semaphore
self._upload_chat_ids, self._forward_chat_ids = self._get_upload_chat_ids()
self._video_ctx = self._create_video_context()
self._cached_message: Message | None = None
async def run(self) -> None:
async with self._semaphore:
self._log.debug('Semaphore for "%s" acquired', self.filename)
await self._run()
self._log.debug('Semaphore for "%s" released', self.filename)
async def _run(self) -> None:
try:
await self._send_upload_text()
await self._upload_video_file()
await asyncio.gather(*(self._send_upload_text(), self._upload_video_file()))
except Exception:
self._log.exception('Exception in upload task for "%s"', self.filename)
@ -163,10 +171,19 @@ class UploadTask(AbstractTask):
)
def _generate_video_caption(self) -> str:
return (
f'{bold("Title:")} {self._body.title}\n'
f'{bold("URL:")} {self._body.context.url}'
)
caption_items = []
if self._users[0].is_base_user:
caption_conf = self._bot.conf.telegram.api.video_caption
else:
caption_conf = self._users[0].upload.video_caption
if caption_conf.include_title:
caption_items.append(f'{bold("Title:")} {self._body.title}')
if caption_conf.include_filename:
caption_items.append(f'{bold("Filename:")} {self._body.filename}')
if caption_conf.include_link:
caption_items.append(f'{bold("URL:")} {self._body.context.url}')
return '\n'.join(caption_items)
async def _save_cache_to_db(self, video: Video | Animation) -> None:
cache = CacheSchema(

View file

@ -24,6 +24,9 @@ class YtdlpNewVersionNotifyTask(AbstractTask):
self._check_interval = get_main_config().ytdlp_version_check_interval
async def run(self) -> None:
await self._run()
async def _run(self) -> None:
while True:
self._log.info('Checking for new yt-dlp version')
try:
@ -38,7 +41,8 @@ class YtdlpNewVersionNotifyTask(AbstractTask):
def _get_next_check_datetime(self) -> datetime.datetime:
return (
datetime.datetime.now() + datetime.timedelta(seconds=self._check_interval)
datetime.datetime.now(datetime.timezone.utc)
+ datetime.timedelta(seconds=self._check_interval)
).replace(microsecond=0)
async def _notify_if_new_version(self) -> None:
@ -49,12 +53,15 @@ class YtdlpNewVersionNotifyTask(AbstractTask):
elif not self._startup_message_sent:
await self._notify_up_to_date(context)
self._startup_message_sent = True
else:
# Explicitly do nothing.
pass
async def _notify_outdated(self, ctx: VersionContext) -> None:
text = (
f'New {code("yt-dlp")} version available: {bold(ctx.latest.version)}\n'
f'Current version: {bold(ctx.current.version)}\n'
f'Rebuild worker with {code("docker-compose build --no-cache worker")}'
f'Rebuild worker with {code("docker compose build --no-cache worker")}'
)
await self._send_to_chat(text)

View file

@ -15,13 +15,13 @@ if TYPE_CHECKING:
from core.bot import VideoBot
class RabbitTaskType(enum.Enum):
class RabbitWorkerType(enum.Enum):
ERROR = 'ERROR'
SUCCESS = 'SUCCESS'
class AbstractResultWorker(AbstractTask):
TYPE: RabbitTaskType | None = None
TYPE: RabbitWorkerType | None = None
QUEUE_TYPE: str | None = None
SCHEMA_CLS: tuple[Type[BaseModel]] = ()

View file

@ -1,11 +1,11 @@
from core.handlers.error import ErrorHandler
from core.workers.abstract import AbstractResultWorker, RabbitTaskType
from core.workers.abstract import AbstractResultWorker, RabbitWorkerType
from yt_shared.rabbit.rabbit_config import ERROR_QUEUE
from yt_shared.schemas.error import ErrorDownloadPayload, ErrorGeneralPayload
class ErrorResultWorker(AbstractResultWorker):
TYPE = RabbitTaskType.ERROR
TYPE = RabbitWorkerType.ERROR
QUEUE_TYPE = ERROR_QUEUE
SCHEMA_CLS = (ErrorDownloadPayload, ErrorGeneralPayload)
HANDLER_CLS = ErrorHandler

View file

@ -2,7 +2,7 @@ import logging
from asyncio import Task
from typing import TYPE_CHECKING
from core.workers.abstract import RabbitTaskType
from core.workers.abstract import RabbitWorkerType
from core.workers.error import ErrorResultWorker
from core.workers.success import SuccessResultWorker
from yt_shared.utils.tasks.tasks import create_task
@ -17,21 +17,25 @@ class RabbitWorkerManager:
def __init__(self, bot: 'VideoBot') -> None:
self._log = logging.getLogger(self.__class__.__name__)
self._bot = bot
self._tasks: dict[RabbitTaskType, Task] = {}
self._workers: dict[RabbitWorkerType, Task] = {}
async def start_worker_tasks(self) -> None:
async def start_workers(self) -> None:
"""Start background result workers.
Workers watch RabbitMQ queues and dispatch payload to the appropriate handlers.
"""
for task_cls in self._TASK_TYPES:
self._log.info('Starting %s', task_cls.__name__)
self._tasks[task_cls.TYPE] = create_task(
self._workers[task_cls.TYPE] = create_task(
task_cls(self._bot).run(),
task_name=task_cls.__name__,
logger=self._log,
exception_message='Rabbit task %s raised an exception',
exception_message='RabbitMQ worker %s raised an exception',
exception_message_args=(task_cls.__name__,),
)
def stop_workers(self):
self._log.info('Stopping %d rabbit tasks', len(self._tasks))
for task_type, task in self._tasks.items():
self._log.info('Stopping %s rabbit task', task_type.value)
task.cancel()
def stop_workers(self) -> None:
self._log.info('Stopping %d RabbitMQ workers', len(self._workers))
for worker_type, worker_task in self._workers.items():
self._log.info('Stopping %s RabbitMQ worker', worker_type.value)
worker_task.cancel()

View file

@ -1,14 +1,25 @@
from core.handlers.success import SuccessHandler
from core.workers.abstract import AbstractResultWorker, RabbitTaskType
from core.workers.abstract import AbstractResultWorker, RabbitWorkerType
from yt_shared.rabbit.rabbit_config import SUCCESS_QUEUE
from yt_shared.schemas.success import SuccessPayload
from yt_shared.utils.tasks.tasks import create_task
class SuccessResultWorker(AbstractResultWorker):
TYPE = RabbitTaskType.SUCCESS
TYPE = RabbitWorkerType.SUCCESS
QUEUE_TYPE = SUCCESS_QUEUE
SCHEMA_CLS = (SuccessPayload,)
HANDLER_CLS = SuccessHandler
async def _process_body(self, body: SuccessPayload) -> None:
await self.HANDLER_CLS(body=body, bot=self._bot).handle()
self._spawn_handler_task(body)
def _spawn_handler_task(self, body: SuccessPayload) -> None:
task_name = self.HANDLER_CLS.__class__.__name__
create_task(
self.HANDLER_CLS(body=body, bot=self._bot).handle(),
task_name=task_name,
logger=self._log,
exception_message='Task %s raised an exception',
exception_message_args=(task_name,),
)

View file

@ -81,29 +81,9 @@ services:
image: "redis:alpine"
container_name: yt_redis
restart: unless-stopped
pgadmin:
container_name: yt_pgadmin
image: dpage/pgadmin4:6
environment:
PGADMIN_DEFAULT_EMAIL: "admin@admin.com"
PGADMIN_DEFAULT_PASSWORD: "123"
PGADMIN_CONFIG_SERVER_MODE: "False"
volumes:
- pgadmin:/root/.pgadmin
ports:
- "5050:80"
restart: unless-stopped
deploy:
resources:
limits:
cpus: "0.5"
memory: 1G
depends_on:
- postgres
volumes:
pgdata:
pgadmin:
shared-tmpfs:
driver: local
driver_opts:

View file

@ -18,5 +18,5 @@ RABBITMQ_PORT=5672
REDIS_HOST=yt_redis
LOG_LEVEL=INFO
LOG_LEVEL=DEBUG
TMP_DOWNLOAD_PATH=/tmp/download_tmpfs

View file

@ -16,7 +16,7 @@ class WorkerLauncher:
self._log = logging.getLogger(self.__class__.__name__)
self._rabbit_mq = get_rabbitmq()
def start(self):
def start(self) -> None:
self._log.info('Starting download worker instance')
loop = asyncio.get_event_loop()
loop.run_until_complete(self._start())
@ -29,7 +29,7 @@ class WorkerLauncher:
await self._perform_setup()
async def _perform_setup(self) -> None:
await asyncio.gather(self._setup_rabbit(), self._set_yt_dlp_version())
await asyncio.gather(*(self._setup_rabbit(), self._set_yt_dlp_version()))
async def _setup_rabbit(self) -> None:
self._log.info('Setting up RabbitMQ connection')
@ -39,7 +39,7 @@ class WorkerLauncher:
)
await self._rabbit_mq.queues[INPUT_QUEUE].consume(cb.on_input_message)
async def _set_yt_dlp_version(self):
async def _set_yt_dlp_version(self) -> None:
curr_version = ytdlp_version.__version__
self._log.info('Setting current yt-dlp version (%s)', curr_version)
async for db in get_db():

View file

@ -3,7 +3,7 @@ import logging
from core.config import settings
def setup_logging():
def setup_logging() -> None:
log_format = '%(asctime)s - [%(levelname)s] - [%(name)s:%(lineno)s] - %(message)s'
logging.basicConfig(
format=log_format, level=logging.getLevelName(settings.LOG_LEVEL)

View file

@ -78,18 +78,10 @@ class VideoService:
tasks = [self._create_thumbnail_task(file_path, thumb_path, video.duration)]
if settings.SAVE_VIDEO_FILE:
tasks.append(self._create_copy_file_task(video))
results = await asyncio.gather(*tasks, return_exceptions=True)
self._raise_on_exception(results)
await asyncio.gather(*tasks)
final_coros = [self._repository.save_as_done(db, task, video)]
results = await asyncio.gather(*final_coros, return_exceptions=True)
self._raise_on_exception(results)
@staticmethod
def _raise_on_exception(results: tuple) -> None:
for result in results:
if isinstance(result, Exception):
raise result
await asyncio.gather(*final_coros)
@staticmethod
async def _set_probe_ctx(file_path: str, video: DownVideo) -> None:

View file

@ -23,11 +23,6 @@ setup(
'Development Status :: 2 - Pre-Alpha',
'Intended Audience :: Developers',
'Natural Language :: English',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
],
description='Common shared utils for yt downloader bot',

View file

@ -42,8 +42,9 @@ class Settings(BaseSettings):
TMP_DOWNLOAD_PATH: str
@classmethod
@validator('LOG_LEVEL')
def validate_log_level_value(cls, value):
def validate_log_level_value(cls, value: str) -> str:
valid_values: KeysView[str] = logging._nameToLevel.keys() # noqa
if value not in valid_values:
raise ValueError(f'"LOG_LEVEL" must be one of {valid_values}')

View file

@ -27,23 +27,23 @@ class RabbitMQ:
await self._set_exchanges()
await self._set_queues()
async def _set_connection(self):
async def _set_connection(self) -> None:
self.connection = await aio_pika.connect_robust(
settings.RABBITMQ_URI,
loop=get_running_loop(),
reconnect_interval=self.RABBITMQ_RECONNECT_INTERVAL,
)
async def _set_channel(self):
async def _set_channel(self) -> None:
self.channel = await self.connection.channel()
await self.channel.set_qos(prefetch_count=self.MAX_UNACK_MESSAGES_PER_CHANNEL)
async def _set_exchanges(self):
async def _set_exchanges(self) -> None:
for exchange_data in self._config.get('exchanges', []):
exchange = await self.channel.declare_exchange(**exchange_data)
self.exchanges[exchange_data['name']] = exchange
async def _set_queues(self):
async def _set_queues(self) -> None:
for queue_data in self._config.get('queues', []):
queue = await self.channel.declare_queue(**queue_data)
queue_name = queue_data['name']

View file

@ -14,6 +14,7 @@ class BaseRabbitPayloadModel(RealBaseModel):
type: RabbitPayloadType = None
@classmethod
@validator('type')
def validate_type_value(cls, v: RabbitPayloadType) -> RabbitPayloadType:
if v is not cls._TYPE:

View file

@ -13,6 +13,7 @@ class LatestVersion(RealBaseModel):
version: StrictStr
retrieved_at: datetime.datetime
@classmethod
@validator('retrieved_at', pre=True)
def remove_microseconds(cls, value: datetime.datetime) -> datetime.datetime:
return _remove_microseconds(value)
@ -22,6 +23,7 @@ class CurrentVersion(RealBaseModel):
version: StrictStr = Field(..., alias='current_version')
updated_at: datetime.datetime
@classmethod
@validator('updated_at', pre=True)
def remove_microseconds(cls, value: datetime.datetime) -> datetime.datetime:
return _remove_microseconds(value)
@ -33,10 +35,9 @@ class CurrentVersion(RealBaseModel):
class VersionContext(RealBaseModel):
latest: LatestVersion
current: CurrentVersion
has_new_version: bool | None
@validator('has_new_version', always=True)
def check_new_version(cls, value, values: dict) -> bool:
return [int(x) for x in values['latest'].version.split('.')] > [
int(x) for x in values['current'].version.split('.')
@property
def has_new_version(self) -> bool:
return [int(x) for x in self.latest.version.split('.')] > [
int(x) for x in self.current.version.split('.')
]