diff --git a/upgrade_server/upgrade_server/api.py b/upgrade_server/upgrade_server/api.py index f77fac5..23c389a 100644 --- a/upgrade_server/upgrade_server/api.py +++ b/upgrade_server/upgrade_server/api.py @@ -134,7 +134,7 @@ def get_user_permissions(username: str): @app.get("/") async def api_root(auth=Depends(get_current_username)): - if crud.is_enabled(): + if crud.is_enabled(config_dict): return { "app": __appname__, } @@ -144,7 +144,7 @@ async def api_root(auth=Depends(get_current_username)): @app.get("/status") async def api_status(): - if crud.is_enabled(): + if crud.is_enabled(config_dict): return { "app": "running", } @@ -229,7 +229,7 @@ async def current_version( detail="User does not have permission to access this resource", ) - if not crud.is_enabled(): + if not crud.is_enabled(config_dict): return CurrentVersion(version="0.00") try: @@ -242,7 +242,7 @@ async def current_version( installed_version=installed_version, group=group, ) - result = crud.get_current_version(target_id) + result = crud.get_current_version(config_dict, target_id) if not result: raise HTTPException(status_code=404, detail="Not found") return result @@ -333,7 +333,7 @@ async def upgrades( detail="User does not have permission to access this resource", ) - if not crud.is_enabled(): + if not crud.is_enabled(config_dict): raise HTTPException( status_code=503, detail="Service is currently disabled for maintenance" ) @@ -349,7 +349,7 @@ async def upgrades( group=group, ) try: - result = crud.get_file(file) + result = crud.get_file(config_dict, file) if not result: raise HTTPException(status_code=404, detail="Not found") return result @@ -440,7 +440,7 @@ async def download( detail="User does not have permission to access this resource", ) - if not crud.is_enabled(): + if not crud.is_enabled(config_dict): raise HTTPException( status_code=503, detail="Service is currently disabled for maintenance" ) @@ -456,7 +456,7 @@ async def download( group=group, ) try: - result = crud.get_file(file, content=True) + result = crud.get_file(config_dict, file, content=True) if not result: raise HTTPException(status_code=404, detail="Not found") headers = {"Content-Disposition": 'attachment; filename="npbackup"'} diff --git a/upgrade_server/upgrade_server/crud.py b/upgrade_server/upgrade_server/crud.py index 5112814..70ac873 100644 --- a/upgrade_server/upgrade_server/crud.py +++ b/upgrade_server/upgrade_server/crud.py @@ -14,29 +14,9 @@ import os from typing import Optional, Union, Tuple from logging import getLogger import hashlib -from argparse import ArgumentParser from datetime import datetime, timezone from upgrade_server.models.files import ClientTargetIdentification, FileGet, FileSend from upgrade_server.models.oper import CurrentVersion -import upgrade_server.configuration as configuration - - -# Make sure we load given config files again -parser = ArgumentParser() -parser.add_argument( - "-c", - "--config-file", - dest="config_file", - type=str, - default=None, - required=False, - help="Path to upgrade_server.conf file", -) -args = parser.parse_args() -if args.config_file: - config_dict = configuration.load_config(args.config_file) -else: - config_dict = configuration.load_config() logger = getLogger() @@ -52,12 +32,12 @@ def sha256sum_data(data): return sha256.hexdigest() -def is_enabled() -> bool: +def is_enabled(config_dict) -> bool: path = os.path.join(config_dict["upgrades"]["data_root"], "DISABLED") return not os.path.isfile(path) -def _get_path_from_target_id(target_id: ClientTargetIdentification) -> Tuple[str, str]: +def _get_path_from_target_id(config_dict, target_id: ClientTargetIdentification) -> Tuple[str, str]: """ Determine specific or generic upgrade path depending on target_id sent by client @@ -123,10 +103,11 @@ def store_host_info(destination: str, host_id: dict) -> None: def get_current_version( + config_dict: dict, target_id: ClientTargetIdentification, ) -> Optional[CurrentVersion]: try: - version_filename, _, _ = _get_path_from_target_id(target_id) + version_filename, _, _ = _get_path_from_target_id(config_dict, target_id) logger.info(f"Searching for version in {version_filename}") if os.path.isfile(version_filename): with open(version_filename, "r", encoding="utf-8") as fh: @@ -141,10 +122,11 @@ def get_current_version( def get_file( + config_dict: dict, file: FileGet, content: bool = False ) -> Optional[Union[FileSend, bytes, bool]]: - _, archive_path, script_path = _get_path_from_target_id(file) + _, archive_path, script_path = _get_path_from_target_id(config_dict, file) unknown_artefact = FileSend( artefact=file.artefact.value,