diff --git a/Makefile b/Makefile index e1bda15..d12c1ea 100644 --- a/Makefile +++ b/Makefile @@ -11,13 +11,14 @@ install: pip install https://github.com/pyrogram/pyrogram/archive/asyncio.zip static_type_check: - mypy media_downloader.py --ignore-missing-imports + mypy media_downloader.py utils --ignore-missing-imports pylint: - pylint media_downloader.py -r y + pylint media_downloader.py utils -r y test: py.test --cov media_downloader --doctest-modules \ + --cov utils \ --cov-report term-missing \ --cov-report html:${TEST_ARTIFACTS} \ --junit-xml=${TEST_ARTIFACTS}/media-downloader.xml \ diff --git a/media_downloader.py b/media_downloader.py index 7b396ea..0da292a 100644 --- a/media_downloader.py +++ b/media_downloader.py @@ -8,6 +8,7 @@ import asyncio import pyrogram import yaml +from utils.file_management import get_next_name, manage_duplicate_file logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -28,6 +29,22 @@ def update_config(config: dict): logger.info("Updated last read message_id to config file") +def _can_download( + _type: str, file_formats: dict, file_format: Optional[str] +) -> bool: + """Check if the given file format can be downloaded""" + if _type in ["audio", "document", "video"]: + allowed_formats: list = file_formats[_type] + if not file_format in allowed_formats and allowed_formats[0] != "all": + return False + return True + + +def _is_exist(file_path: str) -> bool: + """Check if a file exists and it is not a directory""" + return not os.path.isdir(file_path) and os.path.exists(file_path) + + async def _get_media_meta( media_obj: pyrogram.client.types.messages_and_media, _type: str ) -> Tuple[str, str, Optional[str]]: @@ -101,17 +118,6 @@ async def download_media( integer message_id """ - - def _can_download(_type, file_formats, file_format): - if _type in ["audio", "document", "video"]: - allowed_formats: list = file_formats[_type] - if ( - not file_format in allowed_formats - and allowed_formats[0] != "all" - ): - return False - return True - if message.media: for _type in media_types: _media = getattr(message, _type, None) @@ -120,9 +126,16 @@ async def download_media( _media, _type ) if _can_download(_type, file_formats, file_format): - download_path = await client.download_media( - message, file_ref=file_ref, file_name=file_name - ) + if _is_exist(file_name): + file_name = get_next_name(file_name) + download_path = await client.download_media( + message, file_ref=file_ref, file_name=file_name + ) + download_path = manage_duplicate_file(download_path) + else: + download_path = await client.download_media( + message, file_ref=file_ref, file_name=file_name + ) logger.info("Media downloaded - %s", download_path) return message.message_id diff --git a/tests/test_media_downloader.py b/tests/test_media_downloader.py index 0870e3b..a13700b 100644 --- a/tests/test_media_downloader.py +++ b/tests/test_media_downloader.py @@ -11,6 +11,8 @@ import asyncio from media_downloader import ( _get_media_meta, + _can_download, + _is_exist, download_media, update_config, begin_import, @@ -37,6 +39,10 @@ def platform_generic_path(_path: str) -> str: return platform_specific_path +def mock_manage_duplicate_file(file_path: str) -> str: + return file_path + + class MockMessage: def __init__(self, **kwargs): self.message_id = kwargs.get("id") @@ -330,7 +336,7 @@ class MediaDownloaderTestCase(unittest.TestCase): voice=MockVoice( file_ref="AwADBQADbwAD2oTRVeHe5eXRFftfAg", mime_type="audio/ogg", - date=1564066430, + date=1564066340, ), ), MockMessage(id=1214, media=False, text="test message 1",), @@ -343,6 +349,67 @@ class MediaDownloaderTestCase(unittest.TestCase): ) self.assertEqual(result, 1216) + @mock.patch("media_downloader._is_exist", return_value=True) + @mock.patch( + "media_downloader.manage_duplicate_file", + new=mock_manage_duplicate_file, + ) + def test_process_message_when_file_exists(self, mock_is_exist): + client = MockClient() + result = self.loop.run_until_complete( + async_process_messages( + client, + [ + MockMessage( + id=1213, + media=True, + voice=MockVoice( + file_ref="AwADBQADbwAD2oTRVeHe5eXRFftfAg", + mime_type="audio/ogg", + date=1564066340, + ), + ), + MockMessage(id=1214, media=False, text="test message 1",), + MockMessage(id=1215, media=False, text="test message 2",), + MockMessage(id=1216, media=False, text="test message 3",), + ], + ["voice", "photo"], + {"audio": ["all"], "voice": ["all"]}, + ) + ) + self.assertEqual(result, 1216) + + def test_can_download(self): + file_formats = { + "audio": ["mp3"], + "video": ["mp4"], + "document": ["all"], + } + result = _can_download("audio", file_formats, "mp3") + self.assertEqual(result, True) + + result1 = _can_download("audio", file_formats, "ogg") + self.assertEqual(result1, False) + + result2 = _can_download("document", file_formats, "pdf") + self.assertEqual(result2, True) + + result3 = _can_download("document", file_formats, "epub") + self.assertEqual(result3, True) + + def test_is_exist(self): + this_dir = os.path.dirname(os.path.abspath(__file__)) + result = _is_exist(os.path.join(this_dir, "__init__.py")) + self.assertEqual(result, True) + + result1 = _is_exist(os.path.join(this_dir, "init.py")) + self.assertEqual(result1, False) + + result2 = _is_exist(this_dir) + self.assertEqual(result2, False) + + + @classmethod def tearDownClass(cls): cls.loop.close() diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/test-copy1.txt b/tests/utils/test-copy1.txt new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/test_file_management.py b/tests/utils/test_file_management.py new file mode 100644 index 0000000..7ac188e --- /dev/null +++ b/tests/utils/test_file_management.py @@ -0,0 +1,40 @@ +"""Unittest module for media downloader.""" +import os +import sys +import tempfile +import unittest +from pathlib import Path + +import mock + +sys.path.append("..") # Adds higher directory to python modules path. +from utils.file_management import get_next_name, manage_duplicate_file + + +class FileManagementTestCase(unittest.TestCase): + def setUp(self): + self.this_dir = os.path.dirname(os.path.abspath(__file__)) + self.test_file = os.path.join(self.this_dir, "file-test.txt") + self.test_file_copy_1 = os.path.join(self.this_dir, "file-test-copy1.txt") + self.test_file_copy_2 = os.path.join(self.this_dir, "file-test-copy2.txt") + f = open(self.test_file, "w+") + f.write("dummy file") + f.close() + Path(self.test_file_copy_1).touch() + Path(self.test_file_copy_2).touch() + + def test_get_next_name(self): + result = get_next_name(self.test_file) + excepted_result = os.path.join(self.this_dir, "file-test-copy3.txt") + self.assertEqual(result, excepted_result) + + def test_manage_duplicate_file(self): + result = manage_duplicate_file(self.test_file_copy_2) + self.assertEqual(result, self.test_file_copy_1) + + result1 = manage_duplicate_file(self.test_file_copy_1) + self.assertEqual(result1, self.test_file_copy_1) + + def tearDown(self): + os.remove(self.test_file) + os.remove(self.test_file_copy_1) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/file_management.py b/utils/file_management.py new file mode 100644 index 0000000..7590351 --- /dev/null +++ b/utils/file_management.py @@ -0,0 +1,48 @@ +"""Utility functions to handle downloaded files.""" +import glob +import os +import pathlib +from hashlib import md5 + + +def get_next_name(file_path: str) -> str: + """Returns the next available name to download file.""" + posix_path = pathlib.Path(file_path) + counter: int = 1 + new_file_name: str = "{0}/{1}-copy{2}{3}" + while os.path.isfile( + new_file_name.format( + posix_path.parent, + posix_path.stem, + counter, + "".join(posix_path.suffixes), + ) + ): + counter += 1 + return new_file_name.format( + posix_path.parent, + posix_path.stem, + counter, + "".join(posix_path.suffixes), + ) + + +def manage_duplicate_file(file_path: str): + """ + Check if a file is duplicate. + + Compare the md5 of files with copy name pattern + and remove if the md5 hash is same. + """ + posix_path = pathlib.Path(file_path) + file_base_name: str = "".join(posix_path.stem.split("-copy")[0:-1]) + name_pattern: str = f"{posix_path.parent}/{file_base_name}*" + old_files: list = glob.glob(name_pattern) + old_files.remove(file_path) + current_file_md5: str = md5(open(file_path, "rb").read()).hexdigest() + for old_file_path in old_files: + old_file_md5: str = md5(open(old_file_path, "rb").read()).hexdigest() + if current_file_md5 == old_file_md5: + os.remove(file_path) + return old_file_path + return file_path