mirror of
https://github.com/Dineshkarthik/telegram_media_downloader.git
synced 2024-12-28 17:51:03 +08:00
636 lines
18 KiB
Python
636 lines
18 KiB
Python
"""Unittest module for media downloader."""
|
|
import os
|
|
import copy
|
|
import logging
|
|
import platform
|
|
import unittest
|
|
|
|
import asyncio
|
|
import mock
|
|
import pyrogram
|
|
import pytest
|
|
|
|
from media_downloader import (
|
|
_get_media_meta,
|
|
_can_download,
|
|
_is_exist,
|
|
download_media,
|
|
update_config,
|
|
begin_import,
|
|
process_messages,
|
|
main,
|
|
)
|
|
|
|
# Fake project root patched over ``media_downloader.THIS_DIR`` in the tests;
# spelled with backslashes on Windows so expected paths match the platform.
MOCK_DIR: str = (
    "\\root\\project" if platform.system() == "Windows" else "/root/project"
)

# Baseline configuration dict handed to ``begin_import`` by the tests.
MOCK_CONF = {
    "api_id": 123,
    "api_hash": "hasw5Tgawsuj67",
    "last_read_message_id": 0,
    "chat_id": 8654123,
    "ids_to_retry": [1],
    "media_types": ["audio", "voice"],
    "file_formats": {"audio": ["all"], "voice": ["all"]},
}
|
|
|
|
|
|
def platform_generic_path(_path: str) -> str:
    """Return ``_path`` spelled with the running platform's separator.

    On Windows every forward slash is replaced with a backslash; on any
    other platform the string is returned unchanged.
    """
    if platform.system() != "Windows":
        return _path
    return _path.replace("/", "\\")
|
|
|
|
|
|
def mock_manage_duplicate_file(file_path: str) -> str:
    """Identity stand-in patched over ``media_downloader.manage_duplicate_file``.

    The path is handed back untouched.
    """
    unchanged_path: str = file_path
    return unchanged_path
|
|
|
|
|
|
class Chat:
    """Bare chat object exposing only the ``id`` attribute."""

    def __init__(self, chat_id):
        """Store ``chat_id`` (may be ``None``) as ``self.id``."""
        setattr(self, "id", chat_id)
|
|
|
|
|
|
class MockMessage:
    """Keyword-driven message stub.

    Recognised keywords: ``id`` (stored as ``message_id``), ``media``,
    ``audio``, ``document``, ``photo``, ``video``, ``voice``, ``video_note``
    and ``chat_id`` (wrapped in a ``Chat``). Anything missing becomes
    ``None``; any other keyword (the tests pass ``text``) is accepted and
    not stored.
    """

    def __init__(self, **kwargs):
        """Populate the message attributes from ``kwargs``."""
        self.message_id = kwargs.get("id")
        self.media = kwargs.get("media")
        # One attribute per media kind, in the same order as before.
        for media_kind in (
            "audio",
            "document",
            "photo",
            "video",
            "voice",
            "video_note",
        ):
            setattr(self, media_kind, kwargs.get(media_kind))
        self.chat = Chat(kwargs.get("chat_id"))
|
|
|
|
|
|
class MockAudio:
    """Audio stub carrying a ``file_name`` and a ``mime_type``."""

    def __init__(self, **kwargs):
        """Read both required keywords; a missing one raises ``KeyError``."""
        name, mime = kwargs["file_name"], kwargs["mime_type"]
        self.file_name = name
        self.mime_type = mime
|
|
|
|
|
|
class MockDocument:
    """Document stub carrying a ``file_name`` and a ``mime_type``."""

    def __init__(self, **kwargs):
        """Read both required keywords; a missing one raises ``KeyError``."""
        name, mime = kwargs["file_name"], kwargs["mime_type"]
        self.file_name = name
        self.mime_type = mime
|
|
|
|
|
|
class MockPhoto:
    """Photo stub carrying only a ``date``."""

    def __init__(self, **kwargs):
        """Read the required ``date`` keyword; raises ``KeyError`` if absent."""
        timestamp = kwargs["date"]
        self.date = timestamp
|
|
|
|
|
|
class MockVoice:
    """Voice-note stub carrying a ``mime_type`` and a ``date``."""

    def __init__(self, **kwargs):
        """Read both required keywords; a missing one raises ``KeyError``."""
        mime, timestamp = kwargs["mime_type"], kwargs["date"]
        self.mime_type = mime
        self.date = timestamp
|
|
|
|
|
|
class MockVideo:
    """Video stub carrying only a ``mime_type``.

    Other keywords (the tests also pass ``file_name``) are accepted and
    not stored.
    """

    def __init__(self, **kwargs):
        """Read the required ``mime_type``; raises ``KeyError`` if absent."""
        mime = kwargs["mime_type"]
        self.mime_type = mime
|
|
|
|
|
|
class MockVideoNote:
    """Video-note stub carrying a ``mime_type`` and a ``date``."""

    def __init__(self, **kwargs):
        """Read both required keywords; a missing one raises ``KeyError``."""
        mime, timestamp = kwargs["mime_type"], kwargs["date"]
        self.mime_type = mime
        self.date = timestamp
|
|
|
|
|
|
class MockEventLoop:
    """Event-loop stub whose ``run_until_complete`` returns a canned dict."""

    def __init__(self):
        """Nothing to set up; the stub holds no state."""

    def run_until_complete(self, *args, **kwargs):
        """Ignore every argument and return a fixed config-shaped dict."""
        canned_config = dict(api_id=1, api_hash="asdf", ids_to_retry=[1, 2, 3])
        return canned_config
|
|
|
|
|
|
class MockAsync:
    """Stand-in for the ``asyncio`` module, patched in by ``test_main``."""

    def __init__(self):
        """Nothing to set up; the stub holds no state."""

    def get_event_loop(self):
        """Return a brand-new ``MockEventLoop`` on every call."""
        fake_loop = MockEventLoop()
        return fake_loop
|
|
|
|
|
|
async def async_get_media_meta(message_media, _type):
    """Await ``_get_media_meta(message_media, _type)`` and return its result."""
    return await _get_media_meta(message_media, _type)
|
|
|
|
|
|
async def async_download_media(client, message, media_types, file_formats):
    """Await ``download_media`` with the given arguments and return its result."""
    return await download_media(client, message, media_types, file_formats)
|
|
|
|
|
|
async def async_begin_import(conf, pagination_limit):
    """Await ``begin_import(conf, pagination_limit)`` and return its result."""
    return await begin_import(conf, pagination_limit)
|
|
|
|
|
|
async def mock_process_message(*args, **kwargs):
    """Replacement for ``media_downloader.process_messages`` in tests.

    Ignores every argument and always resolves to ``5``, the value
    ``test_begin_import`` expects as ``last_read_message_id``.
    """
    fixed_last_message_id = 5
    return fixed_last_message_id
|
|
|
|
|
|
async def async_process_messages(client, messages, media_types, file_formats):
    """Await ``process_messages`` with the given arguments and return its result."""
    return await process_messages(client, messages, media_types, file_formats)
|
|
|
|
|
|
class MockClient:
    """Async stand-in for ``pyrogram.Client`` with scripted responses.

    ``test_begin_import`` patches it over ``media_downloader.pyrogram.Client``
    and other tests instantiate it directly. Every method returns canned
    data or raises a scripted exception chosen by message id.
    """

    def __init__(self, *args, **kwargs):
        """Accept and discard any constructor arguments."""

    def __aiter__(self):
        """Return ``self``; note that this class defines no ``__anext__``."""
        return self

    async def start(self):
        """Do nothing."""

    async def stop(self):
        """Do nothing."""

    async def iter_history(self, *args, **kwargs):
        """Yield a fixed history: one voice message, then three text-only ones.

        All arguments are ignored.
        """
        items = [
            MockMessage(
                id=1213,
                media=True,
                voice=MockVoice(
                    mime_type="audio/ogg",
                    date=1564066430,
                ),
            ),
            MockMessage(
                id=1214,
                media=False,
                text="test message 1",
            ),
            MockMessage(
                id=1215,
                media=False,
                text="test message 2",
            ),
            MockMessage(
                id=1216,
                media=False,
                text="test message 3",
            ),
        ]
        for item in items:
            yield item

    @staticmethod
    def _mov_message(message_id, chat_id):
        """Build the ``sample_video.mov`` message returned by ``get_messages``."""
        return MockMessage(
            id=message_id,
            media=True,
            chat_id=chat_id,
            video=MockVideo(
                file_name="sample_video.mov",
                mime_type="video/mov",
            ),
        )

    async def get_messages(self, *args, **kwargs):
        """Return the scripted message(s) for ``kwargs["message_ids"]``.

        ``7`` and ``8`` each give a single video message; ``[1]`` gives a
        one-element list. Positional arguments are ignored.

        Raises:
            KeyError: if ``message_ids`` is not passed as a keyword.
            ValueError: if ``message_ids`` is not one of the scripted values.
        """
        message_ids = kwargs["message_ids"]
        if message_ids == 7:
            return self._mov_message(7, 123456)
        if message_ids == 8:
            return self._mov_message(8, 234567)
        if message_ids == [1]:
            return [self._mov_message(1, 234568)]
        # Fail with a readable message instead of the UnboundLocalError the
        # fall-through used to produce.
        raise ValueError(
            "MockClient.get_messages has no scripted response for "
            "message_ids=%r" % (message_ids,)
        )

    async def download_media(self, *args, **kwargs):
        """Return ``kwargs["file_name"]`` or raise a scripted exception.

        The first positional argument is the message. Ids 7 and 8 raise
        pyrogram's ``BadRequest``, id 9 raises ``Unauthorized`` and id 11
        raises ``TypeError``; every other id succeeds.
        """
        mock_message = args[0]
        if mock_message.message_id in [7, 8]:
            raise pyrogram.errors.exceptions.bad_request_400.BadRequest
        elif mock_message.message_id == 9:
            raise pyrogram.errors.exceptions.unauthorized_401.Unauthorized
        elif mock_message.message_id == 11:
            raise TypeError
        return kwargs["file_name"]
|
|
|
|
|
|
class MediaDownloaderTestCase(unittest.TestCase):
    """Tests for the helpers and coroutines imported from ``media_downloader``.

    Coroutines are driven synchronously on a single event loop that is
    created in ``setUpClass`` and closed in ``tearDownClass``.
    """

    @classmethod
    def setUpClass(cls):
        """Create the event loop shared by every test in this class."""
        # A private loop replaces asyncio.get_event_loop(): implicit loop
        # creation is deprecated, and closing a private loop in
        # tearDownClass leaves the thread's default loop alone.
        cls.loop = asyncio.new_event_loop()

    def _run(self, coroutine):
        """Run ``coroutine`` to completion on the class loop; return its result."""
        return self.loop.run_until_complete(coroutine)

    @staticmethod
    def _voice_and_text_messages():
        """Return one voice message (id 1213) followed by three text messages."""
        return [
            MockMessage(
                id=1213,
                media=True,
                voice=MockVoice(
                    mime_type="audio/ogg",
                    date=1564066340,
                ),
            ),
            MockMessage(
                id=1214,
                media=False,
                text="test message 1",
            ),
            MockMessage(
                id=1215,
                media=False,
                text="test message 2",
            ),
            MockMessage(
                id=1216,
                media=False,
                text="test message 3",
            ),
        ]

    @mock.patch("media_downloader.THIS_DIR", new=MOCK_DIR)
    def test_get_media_meta(self):
        """``_get_media_meta`` returns the expected pair for each media type."""
        # Test Voice notes
        message = MockMessage(
            id=1,
            media=True,
            voice=MockVoice(
                mime_type="audio/ogg",
                date=1564066430,
            ),
        )
        result = self._run(async_get_media_meta(message.voice, "voice"))
        self.assertEqual(
            (
                platform_generic_path(
                    "/root/project/voice/voice_2019-07-25T14:53:50.ogg"
                ),
                "ogg",
            ),
            result,
        )

        # Test photos
        message = MockMessage(
            id=2,
            media=True,
            photo=MockPhoto(date=1565015712),
        )
        result = self._run(async_get_media_meta(message.photo, "photo"))
        self.assertEqual(
            (
                platform_generic_path("/root/project/photo/"),
                None,
            ),
            result,
        )

        # Test Documents
        message = MockMessage(
            id=3,
            media=True,
            document=MockDocument(
                file_name="sample_document.pdf",
                mime_type="application/pdf",
            ),
        )
        result = self._run(async_get_media_meta(message.document, "document"))
        self.assertEqual(
            (
                platform_generic_path(
                    "/root/project/document/sample_document.pdf"
                ),
                "pdf",
            ),
            result,
        )

        # Test audio
        message = MockMessage(
            id=4,
            media=True,
            audio=MockAudio(
                file_name="sample_audio.mp3",
                mime_type="audio/mp3",
            ),
        )
        result = self._run(async_get_media_meta(message.audio, "audio"))
        self.assertEqual(
            (
                platform_generic_path("/root/project/audio/sample_audio.mp3"),
                "mp3",
            ),
            result,
        )

        # Test Video
        message = MockMessage(
            id=5,
            media=True,
            video=MockVideo(
                mime_type="video/mp4",
            ),
        )
        result = self._run(async_get_media_meta(message.video, "video"))
        self.assertEqual(
            (
                platform_generic_path("/root/project/video/"),
                "mp4",
            ),
            result,
        )

        # Test VideoNote
        message = MockMessage(
            id=6,
            media=True,
            video_note=MockVideoNote(
                mime_type="video/mp4",
                date=1564066430,
            ),
        )
        result = self._run(
            async_get_media_meta(message.video_note, "video_note")
        )
        self.assertEqual(
            (
                platform_generic_path(
                    "/root/project/video_note/video_note_2019-07-25T14:53:50.mp4"
                ),
                "mp4",
            ),
            result,
        )

    @mock.patch("media_downloader.THIS_DIR", new=MOCK_DIR)
    @mock.patch("media_downloader.asyncio.sleep", return_value=None)
    @mock.patch("media_downloader.logger")
    def test_download_media(self, mock_logger, patched_time_sleep):
        """``download_media`` returns the message id on success and on failure.

        ``MockClient.download_media`` raises a different exception for ids
        7, 8, 9 and 11; the log call expected for each is asserted below.
        """
        client = MockClient()

        def download(message, video_formats):
            """Run ``download_media`` for ``message`` with the given formats."""
            return self._run(
                async_download_media(
                    client,
                    message,
                    ["video", "photo"],
                    {"video": video_formats},
                )
            )

        def mov_message(message_id):
            """Build a message carrying the ``sample_video.mov`` video stub."""
            return MockMessage(
                id=message_id,
                media=True,
                video=MockVideo(
                    file_name="sample_video.mov",
                    mime_type="video/mov",
                ),
            )

        message = MockMessage(
            id=5,
            media=True,
            video=MockVideo(
                file_name="sample_video.mp4",
                mime_type="video/mp4",
            ),
        )
        self.assertEqual(5, download(message, ["mp4"]))

        message_1 = mov_message(6)
        self.assertEqual(6, download(message_1, ["all"]))

        # Test re-fetch message success
        message_2 = mov_message(7)
        self.assertEqual(7, download(message_2, ["all"]))
        mock_logger.warning.assert_called_with(
            "Message[%d]: file reference expired, refetching...", 7
        )

        # Test re-fetch message failure
        message_3 = mov_message(8)
        self.assertEqual(8, download(message_3, ["all"]))
        mock_logger.error.assert_called_with(
            "Message[%d]: file reference expired for 3 retries, download skipped.",
            8,
        )

        # Test other exception
        message_4 = mov_message(9)
        self.assertEqual(9, download(message_4, ["all"]))
        mock_logger.error.assert_called_with(
            "Message[%d]: could not be downloaded due to following exception:\n[%s].",
            9,
            mock.ANY,
            exc_info=True,
        )

        # Check no media
        message_5 = MockMessage(
            id=10,
            media=None,
        )
        self.assertEqual(10, download(message_5, ["all"]))

        # Test timeout
        message_6 = mov_message(11)
        self.assertEqual(11, download(message_6, ["all"]))
        # The spelling "reties" is kept verbatim: the assertion compares
        # against the exact format string passed to the logger.
        mock_logger.error.assert_called_with(
            "Message[%d]: Timing out after 3 reties, download skipped.", 11
        )

    @mock.patch("__main__.__builtins__.open", new_callable=mock.mock_open)
    @mock.patch("media_downloader.yaml", autospec=True)
    def test_update_config(self, mock_yaml, mock_open):
        """``update_config`` opens ``config.yaml`` for writing and dumps into it."""
        conf = {
            "api_id": 123,
            "api_hash": "hasw5Tgawsuj67",
            "ids_to_retry": [],
        }
        update_config(conf)
        mock_open.assert_called_with("config.yaml", "w")
        mock_yaml.dump.assert_called_with(
            conf, mock.ANY, default_flow_style=False
        )

    @mock.patch("media_downloader.update_config")
    @mock.patch("media_downloader.pyrogram.Client", new=MockClient)
    @mock.patch("media_downloader.process_messages", new=mock_process_message)
    def test_begin_import(self, mock_update_config):
        """``begin_import`` returns the config with ``last_read_message_id`` = 5.

        The value 5 is what ``mock_process_message`` always returns.
        """
        # NOTE(review): MOCK_CONF itself is passed in, and the expected dict
        # is copied from it only after the call. If begin_import mutates its
        # argument in place, only the last_read_message_id key is really
        # checked -- confirm against media_downloader.
        result = self._run(async_begin_import(MOCK_CONF, 3))
        conf = copy.deepcopy(MOCK_CONF)
        conf["last_read_message_id"] = 5
        self.assertDictEqual(result, conf)

    def test_process_message(self):
        """``process_messages`` returns 1216, the id of the last message given."""
        client = MockClient()
        result = self._run(
            async_process_messages(
                client,
                self._voice_and_text_messages(),
                ["voice", "photo"],
                {"audio": ["all"], "voice": ["all"]},
            )
        )
        self.assertEqual(result, 1216)

    @mock.patch("media_downloader._is_exist", return_value=True)
    @mock.patch(
        "media_downloader.manage_duplicate_file",
        new=mock_manage_duplicate_file,
    )
    def test_process_message_when_file_exists(self, mock_is_exist):
        """Same as ``test_process_message`` with ``_is_exist`` forced to ``True``."""
        client = MockClient()
        result = self._run(
            async_process_messages(
                client,
                self._voice_and_text_messages(),
                ["voice", "photo"],
                {"audio": ["all"], "voice": ["all"]},
            )
        )
        self.assertEqual(result, 1216)

    def test_can_download(self):
        """``_can_download`` honours explicit formats and the ``"all"`` entry."""
        file_formats = {
            "audio": ["mp3"],
            "video": ["mp4"],
            "document": ["all"],
        }
        result = _can_download("audio", file_formats, "mp3")
        self.assertEqual(result, True)

        result1 = _can_download("audio", file_formats, "ogg")
        self.assertEqual(result1, False)

        result2 = _can_download("document", file_formats, "pdf")
        self.assertEqual(result2, True)

        result3 = _can_download("document", file_formats, "epub")
        self.assertEqual(result3, True)

    def test_is_exist(self):
        """``_is_exist`` is true for an existing file only, not for a directory."""
        this_dir = os.path.dirname(os.path.abspath(__file__))
        result = _is_exist(os.path.join(this_dir, "__init__.py"))
        self.assertEqual(result, True)

        result1 = _is_exist(os.path.join(this_dir, "init.py"))
        self.assertEqual(result1, False)

        result2 = _is_exist(this_dir)
        self.assertEqual(result2, False)

    @mock.patch("media_downloader.FAILED_IDS", [2, 3])
    @mock.patch("media_downloader.yaml.safe_load")
    @mock.patch("media_downloader.update_config", return_value=True)
    @mock.patch("media_downloader.begin_import")
    @mock.patch("media_downloader.asyncio", new=MockAsync())
    def test_main(self, mock_import, mock_update, mock_yaml):
        """``main`` hands the loaded config to ``begin_import`` and saves it.

        ``yaml.safe_load`` is patched to return ``conf``, ``asyncio`` is
        replaced by ``MockAsync`` and ``FAILED_IDS`` is patched to ``[2, 3]``.
        """
        conf = {
            "api_id": 1,
            "api_hash": "asdf",
            "ids_to_retry": [1, 2],
        }
        mock_yaml.return_value = conf
        main()
        mock_import.assert_called_with(conf, pagination_limit=100)
        # Config expected in the update_config call.
        conf["ids_to_retry"] = [1, 2, 3]
        mock_update.assert_called_with(conf)

    @classmethod
    def tearDownClass(cls):
        """Close the event loop created in ``setUpClass``."""
        cls.loop.close()
|