refine code

This commit is contained in:
Benny 2023-12-17 16:29:37 +01:00
parent 62342ee350
commit 1c3becee9d
No known key found for this signature in database
GPG key ID: 6CD0DBDA5235D481
4 changed files with 60 additions and 58 deletions

View file

@ -9,8 +9,6 @@ ENV TZ=Europe/London
RUN apt update && apt install -y --no-install-recommends --no-install-suggests ffmpeg vnstat git aria2
COPY --from=builder /root/.local /usr/local
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo
COPY . /ytdlbot
CMD ["/usr/local/bin/supervisord", "-c" ,"/ytdlbot/conf/supervisor_main.conf"]

View file

@ -13,7 +13,7 @@ from blinker import signal
# general settings
WORKERS: int = int(os.getenv("WORKERS", 10))
PYRO_WORKERS: int = int(os.getenv("PYRO_WORKERS", min(32, (os.cpu_count() or 0) + 4)))
PYRO_WORKERS: int = int(os.getenv("PYRO_WORKERS", min(32, (os.cpu_count() or 0) + 10)))
APP_ID: int = int(os.getenv("APP_ID", 198214))
APP_HASH = os.getenv("APP_HASH", "1234b90")
TOKEN = os.getenv("TOKEN", "1234")

View file

@ -67,7 +67,8 @@ bot = create_app("tasks")
channel = Channel()
def get_messages(chat_id: int, message_id: int):
def retrieve_message(chat_id: int, message_id: int) -> types.Message | Any:
# this should only be called by celery tasks
try:
return bot.get_messages(chat_id, message_id)
except ConnectionError as e:
@ -79,19 +80,27 @@ def get_messages(chat_id: int, message_id: int):
@app.task(rate_limit=f"{RATE_LIMIT}/m")
def ytdl_download_task(chat_id: int, message_id: int, url: str):
logging.info("YouTube celery tasks started for %s", url)
bot_msg = get_messages(chat_id, message_id)
ytdl_normal_download(bot_msg, url)
bot_msg = retrieve_message(chat_id, message_id)
ytdl_normal_download(bot, bot_msg, url)
logging.info("YouTube celery tasks ended.")
@app.task()
def audio_task(chat_id: int, message_id: int):
logging.info("Audio celery tasks started for %s-%s", chat_id, message_id)
bot_msg = get_messages(chat_id, message_id)
normal_audio(bot_msg)
bot_msg = retrieve_message(chat_id, message_id)
normal_audio(bot, bot_msg)
logging.info("Audio celery tasks ended.")
@app.task()
def direct_download_task(chat_id: int, message_id: int, url: str):
logging.info("Direct download celery tasks started for %s", url)
bot_msg = retrieve_message(chat_id, message_id)
direct_normal_download(bot, bot_msg, url)
logging.info("Direct download celery tasks ended.")
def get_unique_clink(original_url: str, user_id: int):
payment = Payment()
settings = payment.get_user_settings(user_id)
@ -104,45 +113,36 @@ def get_unique_clink(original_url: str, user_id: int):
return unique
@app.task()
def direct_download_task(chat_id: int, message_id: int, url: str):
logging.info("Direct download celery tasks started for %s", url)
bot_msg = get_messages(chat_id, message_id)
direct_normal_download(bot, bot_msg, url)
logging.info("Direct download celery tasks ended.")
def forward_video(bot_msg: types.Message, url: str):
redis = Redis()
chat_id = bot_msg.chat.id
unique = get_unique_clink(url, chat_id)
cached_fid = redis.get_send_cache(unique)
if not cached_fid:
redis.update_metrics("cache_miss")
return False
res_msg = upload_processor(bot_msg, url, cached_fid)
def forward_video(client, bot_msg: types.Message | Any, url: str, cached_fid: str):
res_msg = upload_processor(client, bot_msg, url, cached_fid)
obj = res_msg.document or res_msg.video or res_msg.audio or res_msg.animation or res_msg.photo
caption, _ = gen_cap(bot_msg, url, obj)
res_msg.edit_text(caption, reply_markup=gen_video_markup())
bot_msg.edit_text(f"Download success!✅")
redis.update_metrics("cache_hit")
return True
def ytdl_download_entrance(bot_msg: types.Message, url: str, mode=None):
def ytdl_download_entrance(client: Client, bot_msg: types.Message, url: str, mode=None):
# in Local node and forward mode, we pass client from main
# in celery mode, we need to use our own client called bot
payment = Payment()
redis = Redis()
chat_id = bot_msg.chat.id
unique = get_unique_clink(url, chat_id)
cached_fid = redis.get_send_cache(unique)
try:
if forward_video(bot_msg, url):
if cached_fid:
forward_video(client, bot_msg, url, cached_fid)
redis.update_metrics("cache_hit")
return
redis.update_metrics("cache_miss")
mode = mode or payment.get_user_settings(chat_id)[-1]
if ENABLE_CELERY and mode in [None, "Celery"]:
ytdl_download_task.delay(chat_id, bot_msg.id, url)
else:
get_messages(chat_id, bot_msg.id)
ytdl_normal_download(bot_msg, url)
ytdl_normal_download(client, bot_msg, url)
except Exception as e:
logging.error("Failed to download %s, error: %s", url, e)
bot_msg.edit_text(f"Download failed!❌\n\n`{traceback.format_exc()[0:4000]}`", disable_web_page_preview=True)
@ -156,11 +156,11 @@ def direct_download_entrance(client: Client, bot_msg: typing.Union[types.Message
direct_normal_download(client, bot_msg, url)
def audio_entrance(bot_msg: types.Message):
def audio_entrance(client: Client, bot_msg: types.Message):
if ENABLE_CELERY:
audio_task.delay(bot_msg.chat.id, bot_msg.id)
else:
normal_audio(bot_msg)
normal_audio(client, bot_msg)
def direct_normal_download(client: Client, bot_msg: typing.Union[types.Message, typing.Coroutine], url: str):
@ -208,7 +208,7 @@ def direct_normal_download(client: Client, bot_msg: typing.Union[types.Message,
bot_msg.edit_text("Download success!✅")
def normal_audio(bot_msg: typing.Union[types.Message, typing.Coroutine]):
def normal_audio(client: Client, bot_msg: typing.Union[types.Message, typing.Coroutine]):
chat_id = bot_msg.chat.id
# fn = getattr(bot_msg.video, "file_name", None) or getattr(bot_msg.document, "file_name", None)
status_msg: typing.Union[types.Message, typing.Coroutine] = bot_msg.reply_text(
@ -216,36 +216,42 @@ def normal_audio(bot_msg: typing.Union[types.Message, typing.Coroutine]):
)
orig_url: str = re.findall(r"https?://.*", bot_msg.caption)[0]
with tempfile.TemporaryDirectory(prefix="ytdl-", dir=TMPFILE_PATH) as tmp:
bot.send_chat_action(chat_id, enums.ChatAction.RECORD_AUDIO)
client.send_chat_action(chat_id, enums.ChatAction.RECORD_AUDIO)
# just try to download the audio using yt-dlp
filepath = ytdl_download(orig_url, tmp, status_msg, hijack="bestaudio[ext=m4a]")
status_msg.edit_text("Sending audio now...")
bot.send_chat_action(chat_id, enums.ChatAction.UPLOAD_AUDIO)
client.send_chat_action(chat_id, enums.ChatAction.UPLOAD_AUDIO)
for f in filepath:
bot.send_audio(chat_id, f)
client.send_audio(chat_id, f)
status_msg.edit_text("✅ Conversion complete.")
Redis().update_metrics("audio_success")
def ytdl_normal_download(bot_msg: types.Message | typing.Any, url: str):
def ytdl_normal_download(client: Client, bot_msg: types.Message | typing.Any, url: str):
"""
This function is called by celery task or directly by bot
:param client: bot client, either from main or bot(celery)
:param bot_msg: bot message
:param url: url to download
"""
chat_id = bot_msg.chat.id
temp_dir = tempfile.TemporaryDirectory(prefix="ytdl-", dir=TMPFILE_PATH)
video_paths = ytdl_download(url, temp_dir.name, bot_msg)
logging.info("Download complete.")
bot.send_chat_action(chat_id, enums.ChatAction.UPLOAD_DOCUMENT)
client.send_chat_action(chat_id, enums.ChatAction.UPLOAD_DOCUMENT)
bot_msg.edit_text("Download complete. Sending now...")
try:
upload_processor(bot_msg, url, video_paths)
upload_processor(client, bot_msg, url, video_paths)
except pyrogram.errors.Flood as e:
logging.critical("FloodWait from Telegram: %s", e)
bot.send_message(
client.send_message(
chat_id,
f"I'm being rate limited by Telegram. Your video will come after {e} seconds. Please wait patiently.",
)
bot.send_message(OWNER, f"CRITICAL INFO: {e}")
client.send_message(OWNER, f"CRITICAL INFO: {e}")
time.sleep(e.value)
upload_processor(bot_msg, url, video_paths)
upload_processor(client, bot_msg, url, video_paths)
bot_msg.edit_text("Download success!✅")
@ -274,7 +280,7 @@ def generate_input_media(file_paths: list, cap: str) -> list:
return input_media
def upload_processor(bot_msg: types.Message, url: str, vp_or_fid: str | list):
def upload_processor(client: Client, bot_msg: types.Message, url: str, vp_or_fid: str | list):
redis = Redis()
# raise pyrogram.errors.exceptions.FloodWait(13)
# if is str, it's a file id; else it's a list of paths
@ -284,7 +290,7 @@ def upload_processor(bot_msg: types.Message, url: str, vp_or_fid: str | list):
if isinstance(vp_or_fid, list) and len(vp_or_fid) > 1:
# just generate the first for simplicity, send as media group(2-20)
cap, meta = gen_cap(bot_msg, url, vp_or_fid[0])
res_msg: list["types.Message"] | Any = bot.send_media_group(chat_id, generate_input_media(vp_or_fid, cap))
res_msg: list["types.Message"] | Any = client.send_media_group(chat_id, generate_input_media(vp_or_fid, cap))
# TODO no cache for now
return res_msg[0]
elif isinstance(vp_or_fid, list) and len(vp_or_fid) == 1:
@ -303,7 +309,7 @@ def upload_processor(bot_msg: types.Message, url: str, vp_or_fid: str | list):
logging.info("Sending as document")
try:
# send as document could be sent as video even if it's a document
res_msg = bot.send_document(
res_msg = client.send_document(
chat_id,
vp_or_fid,
caption=cap,
@ -315,7 +321,7 @@ def upload_processor(bot_msg: types.Message, url: str, vp_or_fid: str | list):
)
except ValueError:
logging.error("Retry to send as video")
res_msg = bot.send_video(
res_msg = client.send_video(
chat_id,
vp_or_fid,
supports_streaming=True,
@ -327,7 +333,7 @@ def upload_processor(bot_msg: types.Message, url: str, vp_or_fid: str | list):
)
elif settings[2] == "audio":
logging.info("Sending as audio")
res_msg = bot.send_audio(
res_msg = client.send_audio(
chat_id,
vp_or_fid,
caption=cap,
@ -338,7 +344,7 @@ def upload_processor(bot_msg: types.Message, url: str, vp_or_fid: str | list):
# settings==video
logging.info("Sending as video")
try:
res_msg = bot.send_video(
res_msg = client.send_video(
chat_id,
vp_or_fid,
supports_streaming=True,
@ -352,7 +358,7 @@ def upload_processor(bot_msg: types.Message, url: str, vp_or_fid: str | list):
# try to send as annimation, photo
try:
logging.warning("Retry to send as animation")
res_msg = bot.send_animation(
res_msg = client.send_animation(
chat_id,
vp_or_fid,
caption=cap,
@ -364,7 +370,7 @@ def upload_processor(bot_msg: types.Message, url: str, vp_or_fid: str | list):
except Exception:
# this is likely a photo
logging.warning("Retry to send as photo")
res_msg = bot.send_photo(
res_msg = client.send_photo(
chat_id,
vp_or_fid,
caption=cap,
@ -377,7 +383,7 @@ def upload_processor(bot_msg: types.Message, url: str, vp_or_fid: str | list):
redis.add_send_cache(unique, getattr(obj, "file_id", None))
redis.update_metrics("video_success")
if ARCHIVE_ID and isinstance(vp_or_fid, pathlib.Path):
bot.forward_messages(bot_msg.chat.id, ARCHIVE_ID, res_msg.id)
client.forward_messages(bot_msg.chat.id, ARCHIVE_ID, res_msg.id)
return res_msg
@ -473,9 +479,8 @@ def run_celery():
if __name__ == "__main__":
bot.start()
print("Bootstrapping Celery worker now.....")
time.sleep(3)
time.sleep(5)
threading.Thread(target=run_celery, daemon=True).start()
scheduler = BackgroundScheduler(timezone="Europe/London")

View file

@ -451,7 +451,7 @@ def download_handler(client: Client, message: types.Message):
client.send_chat_action(chat_id, enums.ChatAction.UPLOAD_VIDEO)
bot_msg.chat = message.chat
ytdl_download_entrance(bot_msg, url)
ytdl_download_entrance(client, bot_msg, url)
@app.on_callback_query(filters.regex(r"document|video|audio"))
@ -482,7 +482,7 @@ def audio_callback(client: Client, callback_query: types.CallbackQuery):
callback_query.answer(f"Converting to audio...please wait patiently")
redis.update_metrics("audio_request")
audio_entrance(callback_query.message)
audio_entrance(client, callback_query.message)
@app.on_callback_query(filters.regex(r"Local|Celery"))
@ -500,8 +500,7 @@ def periodic_sub_check():
logging.info(f"periodic update:{video_url} - {uids}")
for uid in uids:
try:
bot_msg: types.Message | Any = app.send_message(uid, f"{video_url} is out. Watch it on YouTube")
# ytdl_download_entrance(app, bot_msg, video_url, mode="direct")
app.send_message(uid, f"{video_url} is out. Watch it on YouTube")
except (exceptions.bad_request_400.PeerIdInvalid, exceptions.bad_request_400.UserIsBlocked) as e:
logging.warning("User is blocked or deleted. %s", e)
channel.deactivate_user_subscription(uid)