diff --git a/tests/test_media_downloader.py b/tests/test_media_downloader.py index a13700b..d8a958c 100644 --- a/tests/test_media_downloader.py +++ b/tests/test_media_downloader.py @@ -5,9 +5,10 @@ import logging import platform import unittest -import mock -import pytest import asyncio +import mock +import pyrogram +import pytest from media_downloader import ( _get_media_meta, @@ -17,6 +18,7 @@ from media_downloader import ( update_config, begin_import, process_messages, + main, ) MOCK_DIR: str = "/root/project" @@ -43,6 +45,11 @@ def mock_manage_duplicate_file(file_path: str) -> str: return file_path +class Chat: + def __init__(self, chat_id): + self.id = chat_id + + class MockMessage: def __init__(self, **kwargs): self.message_id = kwargs.get("id") @@ -52,6 +59,7 @@ class MockMessage: self.photo = kwargs.get("photo", None) self.video = kwargs.get("video", None) self.voice = kwargs.get("voice", None) + self.chat = Chat(kwargs.get("chat_id", None)) class MockAudio: @@ -87,6 +95,22 @@ class MockVideo: self.mime_type = kwargs["mime_type"] +class MockEventLoop: + def __init__(self): + pass + + def run_until_complete(self, *args, **kwargs): + return {"api_id": 1, "api_hash": "asdf", "ids_to_retry": [1]} + + +class MockAsync: + def __init__(self): + pass + + def get_event_loop(self): + return MockEventLoop() + + async def async_get_media_meta(message_media, _type): result = await _get_media_meta(message_media, _type) return result @@ -137,18 +161,61 @@ class MockClient: date=1564066430, ), ), - MockMessage(id=1214, media=False, text="test message 1",), - MockMessage(id=1215, media=False, text="test message 2",), - MockMessage(id=1216, media=False, text="test message 3",), + MockMessage( + id=1214, + media=False, + text="test message 1", + ), + MockMessage( + id=1215, + media=False, + text="test message 2", + ), + MockMessage( + id=1216, + media=False, + text="test message 3", + ), ] for item in items: yield item + async def get_messages(self, *args, **kwargs): + if kwargs["message_ids"] == 7: + message = MockMessage( + id=7, + media=True, + chat_id=123456, + video=MockVideo( + file_ref="DwAD94854dd3d5eBe322f4a4DEf22872", + file_name="sample_video.mov", + mime_type="video/mov", + ), + ) + elif kwargs["message_ids"] == 8: + message = MockMessage( + id=8, + media=True, + chat_id=234567, + video=MockVideo( + file_ref="QNzmM3Ww2c00sXhWr4ZJwNT77qaxxP19", + file_name="sample_video.mov", + mime_type="video/mov", + ), + ) + return message + async def download_media(self, *args, **kwargs): - assert "AwADBQADbwAD2oTRVeHe5eXRFftfAg", args[0] + assert "AwADBQADbwAD2oTRVeHe5eXRFftfAg", kwargs[0] assert platform_generic_path( "/root/project/voice/voice_2019-07-25T14:53:50.ogg" ), kwargs["file_name"] + if kwargs["file_ref"] == "QNzmM3Ww2c00sXhWr4ZJwNT77qaxxP19": + raise pyrogram.errors.exceptions.bad_request_400.BadRequest + elif kwargs["file_ref"] == "LGmJOmVpbHbrtmDdzKQx5omdZNq7QNJp": + raise pyrogram.errors.exceptions.unauthorized_401.Unauthorized + elif kwargs["file_ref"] == "sJp5vGa02p1p9bkpU1tVx3OkH2x8cxHK": + raise TypeError return kwargs["file_name"] @@ -272,7 +339,9 @@ class MediaDownloaderTestCase(unittest.TestCase): ) @mock.patch("media_downloader.THIS_DIR", new=MOCK_DIR) - def test_download_media(self): + @mock.patch("media_downloader.asyncio.sleep", return_value=None) + @mock.patch("media_downloader.logger") + def test_download_media(self, mock_logger, patched_time_sleep): client = MockClient() message = MockMessage( id=5, @@ -306,6 +375,102 @@ class MediaDownloaderTestCase(unittest.TestCase): ) self.assertEqual(6, result) + # Test re-fetch message success + message_2 = MockMessage( + id=7, + media=True, + video=MockVideo( + file_ref="QNzmM3Ww2c00sXhWr4ZJwNT77qaxxP19", + file_name="sample_video.mov", + mime_type="video/mov", + ), + ) + result = self.loop.run_until_complete( + async_download_media( + client, message_2, ["video", "photo"], {"video": ["all"]} + ) + ) + self.assertEqual(7, result) + mock_logger.warning.assert_called_with( + "Message[%d]: file reference expired, refetching...", 7 + ) + + # Test re-fetch message failure + message_3 = MockMessage( + id=8, + media=True, + video=MockVideo( + file_ref="QNzmM3Ww2c00sXhWr4ZJwNT77qaxxP19", + file_name="sample_video.mov", + mime_type="video/mov", + ), + ) + result = self.loop.run_until_complete( + async_download_media( + client, message_3, ["video", "photo"], {"video": ["all"]} + ) + ) + self.assertEqual(8, result) + mock_logger.error.assert_called_with( + "Message[%d]: file reference expired for 3 retries, download skipped.", + 8, + ) + + # Test other exception + message_4 = MockMessage( + id=9, + media=True, + video=MockVideo( + file_ref="LGmJOmVpbHbrtmDdzKQx5omdZNq7QNJp", + file_name="sample_video.mov", + mime_type="video/mov", + ), + ) + result = self.loop.run_until_complete( + async_download_media( + client, message_4, ["video", "photo"], {"video": ["all"]} + ) + ) + self.assertEqual(9, result) + mock_logger.error.assert_called_with( + "Message[%d]: could not be downloaded due to following exception:\n[%s].", + 9, + mock.ANY, + exc_info=True, + ) + + # Check no media + message_5 = MockMessage( + id=10, + media=None, + ) + result = self.loop.run_until_complete( + async_download_media( + client, message_5, ["video", "photo"], {"video": ["all"]} + ) + ) + self.assertEqual(10, result) + + # Test timeout + message_6 = MockMessage( + id=11, + media=True, + video=MockVideo( + file_ref="sJp5vGa02p1p9bkpU1tVx3OkH2x8cxHK", + file_name="sample_video.mov", + mime_type="video/mov", + ), + ) + result = self.loop.run_until_complete( + async_download_media( + client, message_6, ["video", "photo"], {"video": ["all"]} + ) + ) + self.assertEqual(11, result) + mock_logger.error.assert_called_with( + "Message[%d]: Timing out after 3 reties, download skipped.", 11 + ) + @mock.patch("__main__.__builtins__.open", new_callable=mock.mock_open) @mock.patch("media_downloader.yaml", autospec=True) def test_update_config(self, mock_yaml, mock_open): @@ -339,9 +504,21 @@ class MediaDownloaderTestCase(unittest.TestCase): date=1564066340, ), ), - MockMessage(id=1214, media=False, text="test message 1",), - MockMessage(id=1215, media=False, text="test message 2",), - MockMessage(id=1216, media=False, text="test message 3",), + MockMessage( + id=1214, + media=False, + text="test message 1", + ), + MockMessage( + id=1215, + media=False, + text="test message 2", + ), + MockMessage( + id=1216, + media=False, + text="test message 3", + ), ], ["voice", "photo"], {"audio": ["all"], "voice": ["all"]}, @@ -369,9 +546,21 @@ class MediaDownloaderTestCase(unittest.TestCase): date=1564066340, ), ), - MockMessage(id=1214, media=False, text="test message 1",), - MockMessage(id=1215, media=False, text="test message 2",), - MockMessage(id=1216, media=False, text="test message 3",), + MockMessage( + id=1214, + media=False, + text="test message 1", + ), + MockMessage( + id=1215, + media=False, + text="test message 2", + ), + MockMessage( + id=1216, + media=False, + text="test message 3", + ), ], ["voice", "photo"], {"audio": ["all"], "voice": ["all"]}, @@ -396,7 +585,7 @@ class MediaDownloaderTestCase(unittest.TestCase): result3 = _can_download("document", file_formats, "epub") self.assertEqual(result3, True) - + def test_is_exist(self): this_dir = os.path.dirname(os.path.abspath(__file__)) result = _is_exist(os.path.join(this_dir, "__init__.py")) @@ -408,7 +597,22 @@ class MediaDownloaderTestCase(unittest.TestCase): result2 = _is_exist(this_dir) self.assertEqual(result2, False) - + @mock.patch("media_downloader.FAILED_IDS", [2, 3]) + @mock.patch("media_downloader.yaml.safe_load") + @mock.patch("media_downloader.update_config", return_value=True) + @mock.patch("media_downloader.begin_import") + @mock.patch("media_downloader.asyncio", new=MockAsync()) + def test_main(self, mock_import, mock_update, mock_yaml): + conf = { + "api_id": 1, + "api_hash": "asdf", + "ids_to_retry": [1, 2], + } + mock_yaml.return_value = conf + main() + mock_import.assert_called_with(conf, pagination_limit=100) + conf["ids_to_retry"] = [1, 2, 3] + mock_update.assert_called_with(conf) @classmethod def tearDownClass(cls):