From 895a1ac3341750a40c7328acc1cf974b6eb37316 Mon Sep 17 00:00:00 2001 From: deajan Date: Sat, 25 Jan 2025 13:10:36 +0100 Subject: [PATCH] upgrade_server: Allow generic upgrade scripts --- upgrade_server/upgrade_server/crud.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/upgrade_server/upgrade_server/crud.py b/upgrade_server/upgrade_server/crud.py index 145f776..14c6981 100644 --- a/upgrade_server/upgrade_server/crud.py +++ b/upgrade_server/upgrade_server/crud.py @@ -39,7 +39,7 @@ def is_enabled(config_dict) -> bool: def _get_path_from_target_id( config_dict, target_id: ClientTargetIdentification -) -> Tuple[str, str]: +) -> Tuple[str, str, str]: """ Determine specific or generic upgrade path depending on target_id sent by client @@ -64,6 +64,7 @@ def _get_path_from_target_id( expected_archive_filename = f"npbackup-{target_id.platform.value}-{target_id.arch.value}-{target_id.build_type.value}-{target_id.audience.value}.{archive_extension}" expected_script_filename = f"npbackup-{target_id.platform.value}-{target_id.arch.value}-{target_id.build_type.value}-{target_id.audience.value}.{script_extension}" + expected_short_script_filename = f"npbackup-{target_id.platform.value}.{script_extension}" base_path = os.path.join( config_dict["upgrades"]["data_root"], @@ -83,10 +84,11 @@ def _get_path_from_target_id( archive_path = os.path.join(base_path, expected_archive_filename) script_path = os.path.join(base_path, expected_script_filename) + short_script_path = os.path.join(base_path, expected_short_script_filename) version_file_path = os.path.join(base_path, "VERSION") - return version_file_path, archive_path, script_path + return version_file_path, archive_path, script_path, short_script_path def store_host_info(destination: str, host_id: dict) -> None: @@ -109,7 +111,7 @@ def get_current_version( target_id: ClientTargetIdentification, ) -> Optional[CurrentVersion]: try: - version_filename, _, _ = _get_path_from_target_id(config_dict, target_id) + version_filename, _, _, _ = _get_path_from_target_id(config_dict, target_id) logger.info(f"Searching for version in {version_filename}") if os.path.isfile(version_filename): with open(version_filename, "r", encoding="utf-8") as fh: @@ -127,7 +129,7 @@ def get_file( config_dict: dict, file: FileGet, content: bool = False ) -> Optional[Union[FileSend, bytes, bool]]: - _, archive_path, script_path = _get_path_from_target_id(config_dict, file) + _, archive_path, script_path, short_script_path = _get_path_from_target_id(config_dict, file) unknown_artefact = FileSend( artefact=file.artefact.value, @@ -152,6 +154,12 @@ def get_file( ) if not os.path.isfile(artefact_path): logger.info(f"No {file.artefact.value} file found in {artefact_path}") + if file.artefact.value == "script": + artefact_path = short_script_path + if not os.path.isfile(artefact_path): + logger.info(f"No {file.artefact.value} file found in {artefact_path}") + if content: + return False if content: return False return unknown_artefact