From 352db785546f09051549b56b5c738cf2b3920388 Mon Sep 17 00:00:00 2001 From: "Dr.-Ing. Amilcar do Carmo Lucas" Date: Sat, 28 Dec 2024 18:29:48 +0100 Subject: [PATCH] FEATURE: check for software updates on the github server This is still work-in-progress --- .github/workflows/pylint.yml | 2 +- README.md | 1 - .../annotate_params.py | 2 +- .../backend_filesystem.py | 37 ++- .../backend_internet.py | 195 ++++++++++++ .../frontend_tkinter_parameter_editor.py | 7 +- .../frontend_tkinter_software_update.py | 91 ++++++ .../middleware_software_updates.py | 114 +++++++ pyproject.toml | 1 + tests/test_annotate_params.py | 62 +++- tests/test_backend_filesystem.py | 287 ++++++++++++----- tests/test_backend_internet.py | 292 ++++++++++++++++++ 12 files changed, 994 insertions(+), 97 deletions(-) create mode 100644 ardupilot_methodic_configurator/backend_internet.py create mode 100644 ardupilot_methodic_configurator/frontend_tkinter_software_update.py create mode 100755 ardupilot_methodic_configurator/middleware_software_updates.py create mode 100755 tests/test_backend_internet.py diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 4e3deff..ac31e37 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -37,7 +37,7 @@ jobs: - name: Install dependencies # these extra packages are required by pylint to validate the python imports run: | - uv pip install 'pylint>=3.3.2' defusedxml requests pymavlink pillow numpy matplotlib pyserial setuptools pytest + uv pip install 'pylint>=3.3.2' defusedxml requests pymavlink pillow numpy matplotlib pyserial setuptools pytest GitPython - name: Analysing the code with pylint run: | diff --git a/README.md b/README.md index 3cafc74..4b54ebe 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,6 @@ SPDX-License-Identifier: GPL-3.0-or-later | [![md-link-check](https://github.com/ArduPilot/MethodicConfigurator/actions/workflows/markdown-link-check.yml/badge.svg)](https://github.com/ArduPilot/MethodicConfigurator/actions/workflows/markdown-link-check.yml) | | [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?logo=discord&logoColor=white)](https://discord.com/channels/674039678562861068/1308233496535371856) | ![PyPI - Downloads](https://img.shields.io/pypi/dm/ardupilot-methodic-configurator?link=https%3A%2F%2Fpypi.org%2Fproject%2Fardupilot-methodic-configurator%2F) | | - *ArduPilot Methodic Configurator* is a software, developed by ArduPilot developers, that semi-automates a [clear, proven and safe configuration sequence](https://ardupilot.github.io/MethodicConfigurator/TUNING_GUIDE_ArduCopter) for ArduCopter drones. We are working on extending it to [ArduPlane](https://ardupilot.github.io/MethodicConfigurator/TUNING_GUIDE_ArduPlane), diff --git a/ardupilot_methodic_configurator/annotate_params.py b/ardupilot_methodic_configurator/annotate_params.py index 8d2ac41..5bfb90e 100755 --- a/ardupilot_methodic_configurator/annotate_params.py +++ b/ardupilot_methodic_configurator/annotate_params.py @@ -392,7 +392,7 @@ def get_xml_data(base_url: str, directory: str, filename: str, vehicle_type: str url = base_url + filename proxies = get_env_proxies() try: - response = requests_get(url, timeout=5, proxies=proxies) + response = requests_get(url, timeout=5, proxies=proxies) if proxies else requests_get(url, timeout=5) if response.status_code != 200: logging.warning("Remote URL: %s", url) msg = f"HTTP status code {response.status_code}" diff --git a/ardupilot_methodic_configurator/backend_filesystem.py b/ardupilot_methodic_configurator/backend_filesystem.py index 92e73f6..d2cb84c 100644 --- a/ardupilot_methodic_configurator/backend_filesystem.py +++ b/ardupilot_methodic_configurator/backend_filesystem.py @@ -9,6 +9,7 @@ """ # from sys import exit as sys_exit +import hashlib from argparse import ArgumentParser from logging import debug as logging_debug from logging import error as logging_error @@ -18,6 +19,7 @@ from os import listdir as os_listdir from os import path as os_path from os import rename as os_rename +from pathlib import Path from platform import system as platform_system from re import compile as re_compile from shutil import copy2 as shutil_copy2 @@ -25,14 +27,14 @@ from typing import Any, Optional from zipfile import ZipFile -from requests import get as requests_get +from git import Repo +from git.exc import InvalidGitRepositoryError from ardupilot_methodic_configurator import _ from ardupilot_methodic_configurator.annotate_params import ( PARAM_DEFINITION_XML_FILE, Par, format_columns, - get_env_proxies, get_xml_dir, get_xml_url, load_default_param_file, @@ -639,20 +641,25 @@ def get_upload_local_and_remote_filenames(self, selected_file: str) -> tuple[str return "", "" @staticmethod - def download_file_from_url(url: str, local_filename: str, timeout: int = 5) -> bool: - if not url or not local_filename: - logging_error(_("URL or local filename not provided.")) - return False - logging_info(_("Downloading %s from %s"), local_filename, url) - response = requests_get(url, timeout=timeout, proxies=get_env_proxies()) - - if response.status_code == 200: - with open(local_filename, "wb") as file: - file.write(response.content) - return True + def get_git_commit_hash() -> str: + try: + repo = Repo(search_parent_directories=True) + return str(repo.head.object.hexsha[:7]) + except InvalidGitRepositoryError: + # Fallback to reading the git_hash.txt file + git_hash_file = os_path.join(os_path.dirname(__file__), "git_hash.txt") + if os_path.exists(git_hash_file): + with open(git_hash_file, encoding="utf-8") as file: + return file.read().strip() + return "" - logging_error(_("Failed to download the file")) - return False + @staticmethod + def verify_file_hash(file_path: Path, expected_hash: str) -> bool: + sha256_hash = hashlib.sha256() + with open(file_path, "rb") as f: + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + return sha256_hash.hexdigest() == expected_hash @staticmethod def add_argparse_arguments(parser: ArgumentParser) -> ArgumentParser: diff --git a/ardupilot_methodic_configurator/backend_internet.py b/ardupilot_methodic_configurator/backend_internet.py new file mode 100644 index 0000000..cc36572 --- /dev/null +++ b/ardupilot_methodic_configurator/backend_internet.py @@ -0,0 +1,195 @@ +""" +Check for software updates and install them if available. + +This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator + +SPDX-FileCopyrightText: 2024-2025 Amilcar Lucas + +SPDX-License-Identifier: GPL-3.0-or-later +""" + +import os +import subprocess +import tempfile +from datetime import datetime, timezone +from logging import error as logging_error +from logging import info as logging_info +from pathlib import Path +from typing import Any, Callable, Optional +from urllib.parse import urljoin + +from requests import HTTPError as requests_HTTPError +from requests import RequestException as requests_RequestException +from requests import Timeout as requests_Timeout +from requests import get as requests_get +from requests.exceptions import RequestException + +from ardupilot_methodic_configurator import _ +from ardupilot_methodic_configurator.backend_filesystem import LocalFilesystem + +# Constants +GITHUB_API_URL_RELEASES = "https://api.github.com/repos/ArduPilot/MethodicConfigurator/releases/" + + +def download_file_from_url( + url: str, local_filename: str, timeout: int = 30, progress_callback: Optional[Callable[[float, str], None]] = None +) -> bool: + if not url or not local_filename: + logging_error(_("URL or local filename not provided.")) + return False + + logging_info(_("Downloading %s from %s"), local_filename, url) + + try: + proxies_dict = { + "http": os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy"), + "https": os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy"), + "no_proxy": os.environ.get("NO_PROXY") or os.environ.get("no_proxy"), + } + + # Remove None values + proxies = {k: v for k, v in proxies_dict.items() if v is not None} + + # Make request with proxy support + response = requests_get( + url, + stream=True, + timeout=timeout, + proxies=proxies, + verify=True, # SSL verification + ) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + block_size = 8192 + downloaded = 0 + + os.makedirs(os.path.dirname(os.path.abspath(local_filename)), exist_ok=True) + + with open(local_filename, "wb") as file: + for chunk in response.iter_content(chunk_size=block_size): + if chunk: + file.write(chunk) + downloaded += len(chunk) + if progress_callback and total_size: + progress = (downloaded / total_size) * 100 + msg = _("Downloading ... {:.1f}%") + progress_callback(progress, msg.format(progress)) + + if progress_callback: + progress_callback(100.0, _("Download complete")) + return True + + except requests_Timeout: + logging_error(_("Download timed out")) + except requests_RequestException as e: + logging_error(_("Network error during download: {}").format(e)) + except OSError as e: + logging_error(_("File system error: {}").format(e)) + except ValueError as e: + logging_error(_("Invalid data received from %s: %s"), url, e) + + return False + + +def get_release_info(name: str, should_be_pre_release: bool, timeout: int = 30) -> dict[str, Any]: + """ + Get release information from GitHub API. + + Args: + name: Release name/path (e.g. '/latest') + should_be_pre_release: Whether the release should be a pre-release + timeout: Request timeout in seconds + + Returns: + Release information dictionary + + Raises: + RequestException: If the request fails + + """ + if not name: + msg = "Release name cannot be empty" + raise ValueError(msg) + + try: + url = urljoin(GITHUB_API_URL_RELEASES, name.lstrip("/")) + response = requests_get(url, timeout=timeout) + response.raise_for_status() + + release_info = response.json() + + if should_be_pre_release and not release_info["prerelease"]: + logging_error(_("The latest continuous delivery build must be a pre-release")) + if not should_be_pre_release and release_info["prerelease"]: + logging_error(_("The latest stable release must not be a pre-release")) + + return release_info # type: ignore[no-any-return] + + except requests_HTTPError as e: + if e.response.status_code == 403: + logging_error(_("Failed to fetch release info: {}").format(e)) + # Get the rate limit reset time + reset_time = int(e.response.headers.get("X-RateLimit-Reset", 0)) + # Create a timezone-aware UTC datetime + reset_datetime = datetime.fromtimestamp(reset_time, timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z") + logging_error(_("Rate limit exceeded. Please try again after: %s (UTC)"), reset_datetime) + raise + except RequestException as e: + logging_error(_("Failed to fetch release info: {}").format(e)) + raise + except (KeyError, ValueError) as e: + logging_error(_("Invalid release data: {}").format(e)) + raise + + +def download_and_install_on_windows( + download_url: str, + file_name: str, + expected_hash: Optional[str] = None, + progress_callback: Optional[Callable[[float, str], None]] = None, +) -> bool: + logging_info(_("Downloading and installing new version for Windows...")) + try: + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = os.path.join(temp_dir, file_name) + + # Download with progress updates + if not download_file_from_url( + download_url, + temp_path, + timeout=60, # Increased timeout for large files + progress_callback=progress_callback, + ): + return False + + if expected_hash and not LocalFilesystem.verify_file_hash(Path(temp_path), expected_hash): + logging_error(_("File hash verification failed")) + return False + + if progress_callback: + progress_callback(100.0, _("Starting installation...")) + + # Run installer + result = subprocess.run( # noqa: S603 + [temp_path], + shell=False, + check=True, + capture_output=True, + text=True, + creationflags=subprocess.CREATE_NO_WINDOW, # type: ignore[attr-defined] + ) + + return result.returncode == 0 + + except subprocess.SubprocessError as e: + logging_error(_("Installation failed: {}").format(e)) + return False + except OSError as e: + logging_error(_("File operation failed: {}").format(e)) + return False + + +def download_and_install_pip_release() -> int: + logging_info(_("Updating via pip for Linux and MacOS...")) + return os.system("pip install --upgrade ardupilot_methodic_configurator") # noqa: S605, S607 diff --git a/ardupilot_methodic_configurator/frontend_tkinter_parameter_editor.py b/ardupilot_methodic_configurator/frontend_tkinter_parameter_editor.py index a6eabdb..f4de1ad 100755 --- a/ardupilot_methodic_configurator/frontend_tkinter_parameter_editor.py +++ b/ardupilot_methodic_configurator/frontend_tkinter_parameter_editor.py @@ -30,6 +30,7 @@ from ardupilot_methodic_configurator.backend_filesystem import LocalFilesystem, is_within_tolerance from ardupilot_methodic_configurator.backend_filesystem_program_settings import ProgramSettings from ardupilot_methodic_configurator.backend_flightcontroller import FlightController +from ardupilot_methodic_configurator.backend_internet import download_file_from_url from ardupilot_methodic_configurator.common_arguments import add_common_arguments_and_parse from ardupilot_methodic_configurator.frontend_tkinter_base import ( AutoResizeCombobox, @@ -459,9 +460,9 @@ def __should_download_file_from_url(self, selected_file: str) -> None: if self.local_filesystem.vehicle_configuration_file_exists(local_filename): return # file already exists in the vehicle directory, no need to download it msg = _("Should the {local_filename} file be downloaded from the URL\n{url}?") - if messagebox.askyesno( - _("Download file from URL"), msg.format(**locals()) - ) and not self.local_filesystem.download_file_from_url(url, local_filename): + if messagebox.askyesno(_("Download file from URL"), msg.format(**locals())) and not download_file_from_url( + url, local_filename + ): error_msg = _("Failed to download {local_filename} from {url}, please download it manually") messagebox.showerror(_("Download failed"), error_msg.format(**locals())) diff --git a/ardupilot_methodic_configurator/frontend_tkinter_software_update.py b/ardupilot_methodic_configurator/frontend_tkinter_software_update.py new file mode 100644 index 0000000..390b98f --- /dev/null +++ b/ardupilot_methodic_configurator/frontend_tkinter_software_update.py @@ -0,0 +1,91 @@ +""" +Check for software updates and install them if available. + +This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator + +SPDX-FileCopyrightText: 2024-2025 Amilcar Lucas + +SPDX-License-Identifier: GPL-3.0-or-later +""" + +import tkinter as tk +from tkinter import ttk +from typing import Callable, Optional + +from ardupilot_methodic_configurator import _, __version__ + + +class UpdateDialog: # pylint: disable=too-many-instance-attributes + """Dialog for displaying software update information and handling user interaction.""" + + def __init__(self, version_info: str, download_callback: Optional[Callable[[], bool]] = None) -> None: + self.root = tk.Tk() + self.root.title(_("Amilcar Lucas's - ArduPilot methodic configurator ") + __version__ + _(" - Update Software")) + self.download_callback = download_callback + self.root.protocol("WM_DELETE_WINDOW", self.on_cancel) + + self.frame = ttk.Frame(self.root, padding="20") + self.frame.grid(sticky="nsew") + + self.msg = ttk.Label(self.frame, text=version_info, wraplength=650, justify="left") + self.msg.grid(row=0, column=0, columnspan=2, pady=20) + + self.progress = ttk.Progressbar(self.frame, orient="horizontal", length=400, mode="determinate") + self.progress.grid(row=1, column=0, columnspan=2, pady=10, padx=10) + self.progress.grid_remove() + + self.status_label = ttk.Label(self.frame, text="") + self.status_label.grid(row=2, column=0, columnspan=2) + + self.result: Optional[bool] = None + self._setup_buttons() + + def _setup_buttons(self) -> None: + self.yes_btn = ttk.Button(self.frame, text=_("Update Now"), command=self.on_yes) + self.no_btn = ttk.Button(self.frame, text=_("Not Now"), command=self.on_no) + self.yes_btn.grid(row=3, column=0, padx=5) + self.no_btn.grid(row=3, column=1, padx=5) + + def update_progress(self, value: float, status: str = "") -> None: + """Update progress directly.""" + self.progress["value"] = value + if status: + self.status_label["text"] = status + self.root.update() + + def on_yes(self) -> None: + self.progress.grid() + self.status_label.grid() + self.yes_btn.config(state="disabled") + self.no_btn.config(state="disabled") + + if self.download_callback: + success = self.download_callback() + if success: + self.status_label["text"] = _("Update complete! Please restart the application.") + self.result = True + else: + self.status_label["text"] = _("Update failed!") + self.yes_btn.config(state="normal") + self.no_btn.config(state="normal") + self.result = False + self.root.after(2000, self.root.destroy) + + def on_no(self) -> None: + self.result = False + self.root.destroy() + + def on_cancel(self) -> None: + self.result = False + self.root.destroy() + + def show(self) -> bool: + """ + Display the update dialog and return user's choice. + + Returns: + bool: True if user chose to update, False otherwise + + """ + self.root.mainloop() + return bool(self.result) diff --git a/ardupilot_methodic_configurator/middleware_software_updates.py b/ardupilot_methodic_configurator/middleware_software_updates.py new file mode 100755 index 0000000..d5d85ba --- /dev/null +++ b/ardupilot_methodic_configurator/middleware_software_updates.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 + +""" +Check for software updates and install them if available. + +This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator + +SPDX-FileCopyrightText: 2024-2025 Amilcar Lucas + +SPDX-License-Identifier: GPL-3.0-or-later +""" + +import platform +from logging import basicConfig as logging_basicConfig +from logging import debug as logging_error +from logging import getLevelName as logging_getLevelName +from logging import info as logging_info +from logging import warning as logging_warning +from typing import Any, Optional + +from requests import RequestException as requests_RequestException + +from ardupilot_methodic_configurator import _ +from ardupilot_methodic_configurator import __version__ as current_version +from ardupilot_methodic_configurator.backend_filesystem import LocalFilesystem +from ardupilot_methodic_configurator.backend_internet import ( + download_and_install_on_windows, + download_and_install_pip_release, + get_release_info, +) +from ardupilot_methodic_configurator.frontend_tkinter_software_update import UpdateDialog + + +def format_version_info(_current_version: str, _latest_release: str, _changes: str) -> str: + return ( + _("New version available!") + + "\n\n" + + _("Current version: {_current_version}") + + "\n" + + _("Latest version: {_latest_release}") + + "\n\n" + + _("Changes:\n{_changes}") + ).format(**locals()) + + +class UpdateManager: # pylint: disable=too-few-public-methods + """Manages the software update process including user interaction and installation.""" + + def __init__(self) -> None: + self.dialog: Optional[UpdateDialog] = None + + def _perform_download(self, latest_release: dict[str, Any]) -> bool: + if platform.system() == "Windows": + try: + asset = latest_release["assets"][0] + return download_and_install_on_windows( + download_url=asset["browser_download_url"], + file_name=asset["name"], + progress_callback=self.dialog.update_progress if self.dialog else None, + ) + except (KeyError, IndexError) as e: + logging_error(_("Error accessing release assets: %s"), e) + return False + return download_and_install_pip_release() == 0 + + def check_and_update(self, latest_release: dict[str, Any], current_version_str: str) -> bool: + try: + latest_version = latest_release["tag_name"].lstrip("v") + if current_version_str == latest_version: + logging_info(_("Already running latest version.")) + return True + + version_info = format_version_info( + current_version_str, latest_version, latest_release.get("body", _("No changes listed")) + ) + self.dialog = UpdateDialog(version_info, download_callback=lambda: self._perform_download(latest_release)) + return self.dialog.show() + + except KeyError as ke: + logging_error(_("Key error during update process: %s"), ke) + return False + except requests_RequestException as req_ex: + logging_error(_("Network error during update process: %s"), req_ex) + return False + except ValueError as val_ex: + logging_error(_("Value error during update process: %s"), val_ex) + return False + + +def check_for_software_updates() -> None: + """Main update orchestration function.""" + git_hash = LocalFilesystem.get_git_commit_hash() + + msg = _("Running version: {} (git hash: {})") + logging_info(msg.format(current_version, git_hash)) + + try: + latest_release = get_release_info("/latest", should_be_pre_release=False) + update_manager = UpdateManager() + update_manager.check_and_update(latest_release, current_version) + except (requests_RequestException, ValueError) as e: + msg = _("Update check failed: {}") + logging_error(msg.format(e)) + + +if __name__ == "__main__": + logging_basicConfig(level=logging_getLevelName("DEBUG"), format="%(asctime)s - %(levelname)s - %(message)s") + logging_warning( + _( + "This main is for testing and development only, usually the check_for_software_updates is" + " called from another script" + ) + ) + check_for_software_updates() diff --git a/pyproject.toml b/pyproject.toml index 901a52a..4e4f0bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dependencies = [ "pillow", "setuptools", "requests", + "GitPython", ] dynamic = ["version"] diff --git a/tests/test_annotate_params.py b/tests/test_annotate_params.py index 1cf5ec1..7a1d2a3 100755 --- a/tests/test_annotate_params.py +++ b/tests/test_annotate_params.py @@ -803,9 +803,69 @@ def test_main_oserror(self, mock_file) -> None: @patch("ardupilot_methodic_configurator.annotate_params.get_xml_url") def test_get_xml_url_exception(self, mock_get_xml_url_) -> None: mock_get_xml_url_.side_effect = ValueError("Mocked Value Error") - with pytest.raises(ValueError, match="Vehicle type 'NonExistingVehicle' is not supported."): + with pytest.raises(ValueError, match="Vehicle type 'NonExistingVehicle' is not supported."): # noqa: PT012 get_xml_url("NonExistingVehicle", "4.0") + @patch("requests.get") + def test_get_xml_data_remote_file(mock_get) -> None: + """Test fetching XML data from remote file.""" + # Mock the response + mock_get.return_value.status_code = 200 + mock_get.return_value.text = "" + + # Remove the test.xml file if it exists + with contextlib.suppress(FileNotFoundError): + os.remove("test.xml") + + # Call the function with a remote file + result = get_xml_data("http://example.com/", ".", "test.xml", "ArduCopter") + + # Check the result + assert isinstance(result, ET.Element) + + # Assert that requests.get was called once with correct parameters including proxies + mock_get.assert_called_once_with("http://example.com/test.xml", timeout=5, proxies=None) + + @patch("requests.get") + def test_get_xml_data_remote_file_with_proxies(mock_get) -> None: + """Test fetching XML data with proxy configuration.""" + # Mock environment variables + with patch.dict( + os.environ, + {"HTTP_PROXY": "http://proxy:8080", "HTTPS_PROXY": "https://proxy:8080", "NO_PROXY": "localhost"}, + ): + # Mock the response + mock_get.return_value.status_code = 200 + mock_get.return_value.text = "" + + # Call the function + result = get_xml_data("http://example.com/", ".", "test.xml", "ArduCopter") + + # Check the result + assert isinstance(result, ET.Element) + + # Assert that requests.get was called with proxy settings + expected_proxies = {"http": "http://proxy:8080", "https": "https://proxy:8080", "no_proxy": "localhost"} + mock_get.assert_called_once_with("http://example.com/test.xml", timeout=5, proxies=expected_proxies) + + @patch("requests.get") + def test_get_xml_data_remote_file_no_proxies(mock_get) -> None: + """Test fetching XML data with no proxy configuration.""" + # Clear environment variables + with patch.dict(os.environ, {}, clear=True): + # Mock the response + mock_get.return_value.status_code = 200 + mock_get.return_value.text = "" + + # Call the function + result = get_xml_data("http://example.com/", ".", "test.xml", "ArduCopter") + + # Check the result + assert isinstance(result, ET.Element) + + # Assert that requests.get was called with no proxies + mock_get.assert_called_once_with("http://example.com/test.xml", timeout=5, proxies=None) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_backend_filesystem.py b/tests/test_backend_filesystem.py index cf3fb99..445d345 100755 --- a/tests/test_backend_filesystem.py +++ b/tests/test_backend_filesystem.py @@ -10,11 +10,14 @@ SPDX-License-Identifier: GPL-3.0-or-later """ +import hashlib import unittest +from argparse import ArgumentParser from os import path as os_path +from pathlib import Path from unittest.mock import MagicMock, patch -from requests.exceptions import ConnectTimeout +from git.exc import InvalidGitRepositoryError from ardupilot_methodic_configurator.backend_filesystem import LocalFilesystem @@ -222,99 +225,89 @@ def test_add_configuration_file_to_zip(self) -> None: mock_join.assert_called_once_with("vehicle_dir", "test_file.param") mock_zipfile.write.assert_called_once_with("vehicle_dir/test_file.param", arcname="test_file.param") - @patch("ardupilot_methodic_configurator.backend_filesystem.requests_get") - def test_download_file_from_url(self, mock_get) -> None: - """Test file download functionality with various scenarios.""" - # Setup mock response for successful download - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.content = b"test content" - mock_get.return_value = mock_response - mock_get.side_effect = None # Clear any previous side effects - - # Test successful download - mock_open_obj = MagicMock() - mock_file_obj = MagicMock() - mock_open_obj.return_value.__enter__.return_value = mock_file_obj - - with patch("builtins.open", mock_open_obj): - result = LocalFilesystem.download_file_from_url("http://test.com/file", "local_file.txt") - assert result - mock_get.assert_called_once_with("http://test.com/file", timeout=5) - mock_open_obj.assert_called_once_with("local_file.txt", "wb") - mock_file_obj.write.assert_called_once_with(b"test content") - - # Test failed download (404) - mock_get.reset_mock() - mock_get.side_effect = None # Reset side effect - mock_response.status_code = 404 - result = LocalFilesystem.download_file_from_url("http://test.com/not_found", "local_file.txt") - assert not result - - # Test with empty URL - mock_get.reset_mock() - result = LocalFilesystem.download_file_from_url("", "local_file.txt") - assert not result - mock_get.assert_not_called() - - # Test with empty local filename - mock_get.reset_mock() - result = LocalFilesystem.download_file_from_url("http://test.com/file", "") - assert not result - mock_get.assert_not_called() + def test_get_upload_local_and_remote_filenames_missing_file(self) -> None: + """Test get_upload_local_and_remote_filenames with missing file.""" + lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False) + result = lfs.get_upload_local_and_remote_filenames("missing_file") + assert result == ("", "") - # Test download with connection timeout - mock_get.reset_mock() + def test_get_upload_local_and_remote_filenames_missing_upload_section(self) -> None: + """Test get_upload_local_and_remote_filenames with missing upload section.""" + lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False) + lfs.configuration_steps = {"selected_file": {}} + result = lfs.get_upload_local_and_remote_filenames("selected_file") + assert result == ("", "") - mock_get.side_effect = ConnectTimeout() - # this should be fixed at some point - # result = LocalFilesystem.download_file_from_url("http://test.com/timeout", "local_file.txt") - # assert not result - # mock_get.assert_called_once_with("http://test.com/timeout", timeout=5) + def test_get_download_url_and_local_filename_missing_file(self) -> None: + """Test get_download_url_and_local_filename with missing file.""" + lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False) + result = lfs.get_download_url_and_local_filename("missing_file") + assert result == ("", "") - def test_get_download_url_and_local_filename(self) -> None: + def test_get_download_url_and_local_filename_missing_download_section(self) -> None: + """Test get_download_url_and_local_filename with missing download section.""" lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False) - with patch("os.path.join") as mock_join: - mock_join.return_value = "vehicle_dir/dest_local" - lfs.configuration_steps = { - "selected_file": {"download_file": {"source_url": "http://example.com/file", "dest_local": "dest_local"}} - } - result = lfs.get_download_url_and_local_filename("selected_file") - assert result == ("http://example.com/file", "vehicle_dir/dest_local") - mock_join.assert_called_once_with("vehicle_dir", "dest_local") + lfs.configuration_steps = {"selected_file": {}} + result = lfs.get_download_url_and_local_filename("selected_file") + assert result == ("", "") - def test_get_upload_local_and_remote_filenames(self) -> None: + def test_write_and_read_last_uploaded_filename_error_handling(self) -> None: + """Test error handling in write_and_read_last_uploaded_filename.""" lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False) - with patch("os.path.join") as mock_join: - mock_join.return_value = "vehicle_dir/source_local" - lfs.configuration_steps = { - "selected_file": {"upload_file": {"source_local": "source_local", "dest_on_fc": "dest_on_fc"}} - } - result = lfs.get_upload_local_and_remote_filenames("selected_file") - assert result == ("vehicle_dir/source_local", "dest_on_fc") - mock_join.assert_called_once_with("vehicle_dir", "source_local") + test_filename = "test.param" + + # Test write error + with patch("builtins.open", side_effect=OSError("Write error")): + lfs.write_last_uploaded_filename(test_filename) + # Should not raise exception, just log error + + # Test read error + with patch("builtins.open", side_effect=OSError("Read error")): + result = lfs._LocalFilesystem__read_last_uploaded_filename() # pylint: disable=protected-access + assert result == "" - def test_copy_fc_values_to_file(self) -> None: + def test_copy_fc_values_to_file_with_missing_params(self) -> None: + """Test copy_fc_values_to_file with missing parameters.""" lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False) - test_params = {"PARAM1": 1.0, "PARAM2": 2.0} + test_params = {"PARAM1": 1.0, "PARAM2": 2.0, "PARAM3": 3.0} test_file = "test.param" - # Test with non-existent file - result = lfs.copy_fc_values_to_file(test_file, test_params) - assert result == 0 - - # Test with existing file and matching parameters + # Test with partially matching parameters lfs.file_parameters = {test_file: {"PARAM1": MagicMock(value=0.0), "PARAM2": MagicMock(value=0.0)}} result = lfs.copy_fc_values_to_file(test_file, test_params) assert result == 2 assert lfs.file_parameters[test_file]["PARAM1"].value == 1.0 assert lfs.file_parameters[test_file]["PARAM2"].value == 2.0 - def test_write_and_read_last_uploaded_filename(self) -> None: + def test_get_start_file_empty_files(self) -> None: + """Test get_start_file with empty files list.""" lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False) - test_filename = "test.param" + lfs.file_parameters = {} + result = lfs.get_start_file(1, True) # noqa: FBT003 + assert result == "" + + def test_get_eval_variables_with_none(self) -> None: + """Test get_eval_variables with None values.""" + lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False) + lfs.vehicle_components = None + lfs.doc_dict = None + result = lfs.get_eval_variables() + assert not result + + def test_tolerance_check_with_zero_values(self) -> None: + """Test numerical value comparison with zero values.""" + lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False) + + # Test zero values + x, y = 0.0, 0.0 + assert abs(x - y) <= 1e-08 + (1e-03 * abs(y)) + + # Test small positive vs zero + x, y = 1e-10, 0.0 + assert abs(x - y) <= 1e-08 + (1e-03 * abs(y)) # Test writing + test_filename = "test_param.param" expected_path = os_path.join("vehicle_dir", "last_uploaded_filename.txt") with patch("builtins.open", unittest.mock.mock_open()) as mock_file: lfs.write_last_uploaded_filename(test_filename) @@ -327,6 +320,29 @@ def test_write_and_read_last_uploaded_filename(self) -> None: assert result == test_filename mock_file.assert_called_once_with(expected_path, encoding="utf-8") + def test_tolerance_handling(self) -> None: + """Test parameter value tolerance checking.""" + # Setup LocalFilesystem instance + from ardupilot_methodic_configurator.backend_filesystem import ( # pylint: disable=import-outside-toplevel + is_within_tolerance, + ) + + # Test cases within tolerance (default 0.1%) + assert is_within_tolerance(10.0, 10.009) # +0.09% - should pass + assert is_within_tolerance(10.0, 9.991) # -0.09% - should pass + assert is_within_tolerance(100, 100) # Exact match + assert is_within_tolerance(0.0, 0.0) # Zero case + + # Test cases outside tolerance + assert not is_within_tolerance(10.0, 10.02) # +0.2% - should fail + assert not is_within_tolerance(10.0, 9.98) # -0.2% - should fail + assert not is_within_tolerance(100, 101) # Integer case + + # Test with custom tolerance + custom_tolerance = 0.2 # 0.2% + assert is_within_tolerance(10.0, 10.015, atol=custom_tolerance) # +0.15% - should pass + assert is_within_tolerance(10.0, 9.985, atol=custom_tolerance) # -0.15% - should pass + def test_write_param_default_values(self) -> None: lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False) @@ -503,6 +519,127 @@ def test_all_intermediate_parameter_file_comments(self) -> None: result = lfs._LocalFilesystem__all_intermediate_parameter_file_comments() # pylint: disable=protected-access assert result == {"PARAM1": "Comment 1", "PARAM2": "Override comment 2", "PARAM3": "Comment 3"} + def test_get_git_commit_hash(self) -> None: + # Test with valid git repo + with patch("ardupilot_methodic_configurator.backend_filesystem.Repo") as mock_repo: + mock_repo.return_value.head.object.hexsha = "abcdef1234567890" + result = LocalFilesystem.get_git_commit_hash() + assert result == "abcdef1" + + # Test with no git repo but git_hash.txt exists + with ( + patch("ardupilot_methodic_configurator.backend_filesystem.Repo") as mock_repo, + patch("ardupilot_methodic_configurator.backend_filesystem.os_path.exists") as mock_exists, + patch("builtins.open", unittest.mock.mock_open(read_data="xyz1234")), + ): + mock_repo.side_effect = InvalidGitRepositoryError() + mock_exists.return_value = True + result = LocalFilesystem.get_git_commit_hash() + assert result == "xyz1234" + + # Test with no git repo and no git_hash.txt + with ( + patch("ardupilot_methodic_configurator.backend_filesystem.Repo") as mock_repo, + patch("ardupilot_methodic_configurator.backend_filesystem.os_path.exists") as mock_exists, + ): + mock_repo.side_effect = InvalidGitRepositoryError() + mock_exists.return_value = False + result = LocalFilesystem.get_git_commit_hash() + assert result == "" + + def test_verify_file_hash(self) -> None: + test_content = b"test content" + test_hash = hashlib.sha256(test_content).hexdigest() + mock_file = unittest.mock.mock_open(read_data=test_content) + + with patch("builtins.open", mock_file): + result = LocalFilesystem.verify_file_hash(Path("test_file"), test_hash) + assert result is True + + # Test with incorrect hash + result = LocalFilesystem.verify_file_hash(Path("test_file"), "wrong_hash") + assert result is False + + def test_extend_and_reformat_parameter_documentation_metadata(self) -> None: + lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False) + + test_doc_dict = { + "PARAM1": { + "humanName": "Test Param", + "documentation": ["Test documentation"], + "fields": { + "Units": "m/s (meters per second)", + "Range": "0 100", + "Calibration": "true", + "ReadOnly": "yes", + "RebootRequired": "1", + "Bitmask": "0:Option1, 1:Option2", + }, + "values": {"1": "Value1", "2": "Value2"}, + } + } + + lfs.doc_dict = test_doc_dict + lfs.param_default_dict = {"PARAM1": MagicMock(value=50.0)} + + lfs._LocalFilesystem__extend_and_reformat_parameter_documentation_metadata() # pylint: disable=protected-access + + result = lfs.doc_dict["PARAM1"] + assert result["unit"] == "m/s" + assert result["unit_tooltip"] == "meters per second" + assert result["min"] == 0.0 + assert result["max"] == 100.0 + assert result["Calibration"] is True + assert result["ReadOnly"] is True + assert result["RebootRequired"] is True + assert result["Bitmask"] == {0: "Option1", 1: "Option2"} + assert result["Values"] == {1: "Value1", 2: "Value2"} + assert "Default: 50" in result["doc_tooltip"] + + def test_add_argparse_arguments(self) -> None: + parser = ArgumentParser() + LocalFilesystem.add_argparse_arguments(parser) + + # Verify all expected arguments are added + args = vars(parser.parse_args([])) + assert "vehicle_type" in args + assert "vehicle_dir" in args + assert "n" in args + assert "allow_editing_template_files" in args + + # Test default values + assert args["vehicle_type"] == "" + assert args["n"] == -1 + assert args["allow_editing_template_files"] is False + + # Test with parameters + args = vars(parser.parse_args(["-t", "ArduCopter", "--n", "1", "--allow-editing-template-files"])) + assert args["vehicle_type"] == "ArduCopter" + assert args["n"] == 1 + assert args["allow_editing_template_files"] is True + + def test_annotate_intermediate_comments_to_param_dict(self) -> None: + lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False) + + # Set up mock comments in file_parameters + mock_param1 = MagicMock() + mock_param1.comment = "Comment 1" + mock_param2 = MagicMock() + mock_param2.comment = "Comment 2" + + lfs.file_parameters = {"file1.param": {"PARAM1": mock_param1}, "file2.param": {"PARAM2": mock_param2}} + + input_dict = {"PARAM1": 1.0, "PARAM2": 2.0, "PARAM3": 3.0} + result = lfs.annotate_intermediate_comments_to_param_dict(input_dict) + + assert len(result) == 3 + assert result["PARAM1"].value == 1.0 + assert result["PARAM1"].comment == "Comment 1" + assert result["PARAM2"].value == 2.0 + assert result["PARAM2"].comment == "Comment 2" + assert result["PARAM3"].value == 3.0 + assert result["PARAM3"].comment == "" + class TestCopyTemplateFilesToNewVehicleDir(unittest.TestCase): """Copy Template Files To New Vehicle Directory testclass.""" diff --git a/tests/test_backend_internet.py b/tests/test_backend_internet.py new file mode 100755 index 0000000..ec0d896 --- /dev/null +++ b/tests/test_backend_internet.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 + +""" +Tests for the backend_internet.py file. + +This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator + +SPDX-FileCopyrightText: 2024-2025 Amilcar do Carmo Lucas + +SPDX-License-Identifier: GPL-3.0-or-later +""" + +import os +from unittest.mock import Mock, patch + +import pytest +from requests import HTTPError as requests_HTTPError +from requests import RequestException as requests_RequestException + +from ardupilot_methodic_configurator.backend_internet import ( + download_and_install_on_windows, + download_and_install_pip_release, + download_file_from_url, + get_release_info, +) + + +def test_download_file_from_url_empty_params() -> None: + assert not download_file_from_url("", "") + assert not download_file_from_url("http://test.com", "") + assert not download_file_from_url("", "test.txt") + + +@pytest.mark.parametrize( + "env_vars", + [ + {}, + {"HTTP_PROXY": "http://proxy:8080"}, + {"HTTPS_PROXY": "https://proxy:8080"}, + {"NO_PROXY": "localhost"}, + ], +) +def test_download_file_from_url_proxy_handling(env_vars) -> None: + with patch.dict(os.environ, env_vars, clear=True), patch("requests.get") as mock_get: + mock_get.return_value.status_code = 404 + assert not download_file_from_url("http://test.com", "test.txt") + + +@patch("ardupilot_methodic_configurator.backend_internet.requests_get") +def test_download_file_success(mock_get, tmp_path) -> None: + # Setup mock response + mock_response = Mock() + mock_response.headers = {"content-length": "100"} + mock_response.iter_content.return_value = [b"test data"] + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + test_file = tmp_path / "test.txt" + assert download_file_from_url("http://test.com", str(test_file)) + assert test_file.read_bytes() == b"test data" + + +@pytest.fixture +def mock_get_() -> Mock: + with patch("ardupilot_methodic_configurator.backend_internet.requests_get") as _mock: + yield _mock + + +@patch("ardupilot_methodic_configurator.backend_internet.requests_get") +def test_download_file_invalid_content_length(mock_get) -> None: + # Test handling of invalid content-length header + mock_response = Mock() + mock_response.headers = {"content-length": "invalid"} + mock_response.iter_content.return_value = [b"test data"] + mock_get.return_value = mock_response + assert not download_file_from_url("http://test.com", "test.txt") + + +@patch("ardupilot_methodic_configurator.backend_internet.requests_get") +def test_download_file_missing_content_length(mock_get) -> None: + # Test handling of missing content-length header + mock_response = Mock() + mock_response.headers = {} + mock_response.iter_content.return_value = [b"test data"] + mock_get.return_value = mock_response + assert download_file_from_url("http://test.com", "test.txt") + + +@patch("ardupilot_methodic_configurator.backend_internet.requests_get") +def test_download_file_empty_response(mock_get) -> None: + # Test handling of empty response + mock_response = Mock() + mock_response.headers = {"content-length": "0"} + mock_response.iter_content.return_value = [] + mock_get.return_value = mock_response + assert download_file_from_url("http://test.com", "test.txt") + + +@patch("ardupilot_methodic_configurator.backend_internet.requests_get") +def test_download_file_http_error(mock_get) -> None: + # Test HTTP error handling + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests_HTTPError("404 Not Found") + mock_get.return_value = mock_response + assert not download_file_from_url("http://test.com", "test.txt") + + +@patch("ardupilot_methodic_configurator.backend_internet.requests_get") +def test_download_file_with_progress_no_content_length(mock_get, tmp_path) -> None: + # Test progress callback without content-length header + mock_response = Mock() + mock_response.headers = {} + mock_response.iter_content.return_value = [b"data"] * 4 + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + progress_callback = Mock() + test_file = tmp_path / "test.txt" + + assert download_file_from_url("http://test.com", str(test_file), progress_callback=progress_callback) + assert progress_callback.call_count == 1 # Only final callback + progress_callback.assert_called_with(100.0, "Download complete") + + +@patch("ardupilot_methodic_configurator.backend_internet.requests_get") +def test_download_file_proxy_configuration(mock_get, monkeypatch) -> None: + # Test proxy configuration handling + mock_response = Mock() + mock_response.headers = {"content-length": "100"} + mock_response.iter_content.return_value = [b"test data"] + mock_get.return_value = mock_response + + # Set environment variables + monkeypatch.setenv("HTTP_PROXY", "http://proxy:8080") + monkeypatch.setenv("HTTPS_PROXY", "https://proxy:8080") + monkeypatch.setenv("NO_PROXY", "localhost") + + assert download_file_from_url("http://test.com", "test.txt") + mock_get.assert_called_once_with( + "http://test.com", + stream=True, + timeout=30, + proxies={"http": "http://proxy:8080", "https": "https://proxy:8080", "no_proxy": "localhost"}, + verify=True, + ) + + +@patch("ardupilot_methodic_configurator.backend_internet.requests_get") +def test_download_file_value_error(mock_get) -> None: + # Test handling of ValueError during download + mock_response = Mock() + mock_response.headers = {"content-length": "100"} + mock_response.iter_content.side_effect = ValueError("Invalid data") + mock_get.return_value = mock_response + assert not download_file_from_url("http://test.com", "test.txt") + + +def test_download_file_from_url_invalid_url() -> None: + # Test with invalid URL format + assert not download_file_from_url("not_a_valid_url", "test.txt") + + +@patch("ardupilot_methodic_configurator.backend_internet.requests_get") +def test_download_file_unicode_error(mock_get) -> None: + # Test handling of Unicode decode errors + mock_response = Mock() + mock_response.headers = {"content-length": "100"} + mock_response.iter_content.return_value = [bytes([0xFF, 0xFE, 0xFD])] # Invalid UTF-8 + mock_get.return_value = mock_response + assert download_file_from_url("http://test.com", "test.txt") + + +@patch("ardupilot_methodic_configurator.backend_internet.requests_get") +def test_get_release_info_invalid_release(mock_get) -> None: + mock_get.side_effect = requests_RequestException() + with pytest.raises(requests_RequestException): + get_release_info("latest", False) # noqa: FBT003 + + +@patch("ardupilot_methodic_configurator.backend_internet.requests_get") +def test_get_release_info_prerelease_mismatch(mock_get) -> None: + mock_response = Mock() + mock_response.json.return_value = {"prerelease": True} + mock_get.return_value = mock_response + + +@patch("ardupilot_methodic_configurator.backend_internet.download_file_from_url") +def test_download_and_install_windows_download_failure(mock_download) -> None: + mock_download.return_value = False + assert not download_and_install_on_windows("http://test.com", "test.exe") + + +@patch("os.system") +def test_download_and_install_pip_release(mock_system) -> None: + mock_system.return_value = 0 + assert download_and_install_pip_release() == 0 + + +class TestDownloadFile: + @pytest.fixture + def mock_response(self) -> Mock: + response = Mock() + response.headers = {"content-length": "100"} + response.iter_content.return_value = [b"test data"] + response.raise_for_status.return_value = None + return response + + @pytest.fixture + def mock_get(self) -> Mock: + with patch("ardupilot_methodic_configurator.backend_internet.requests_get") as _mock: + yield _mock + + def test_download_file_network_errors(self, mock_get, caplog) -> None: + errors = [ + requests_HTTPError("404 Not Found"), + requests_RequestException("Connection failed"), + ValueError("Invalid response"), + OSError("File system error"), + ] + + for error in errors: + mock_get.side_effect = error + assert not download_file_from_url("http://test.com", "test.txt") + assert str(error) in caplog.text + caplog.clear() + + def test_download_file_progress_tracking(self, mock_get, mock_response, tmp_path) -> None: + mock_get.return_value = mock_response + progress_values = [] + + def progress_callback(progress: float, msg: str) -> None: + progress_values.append((progress, msg)) + + test_file = tmp_path / "test.txt" + assert download_file_from_url("http://test.com", str(test_file), progress_callback=progress_callback) + + # Verify progress tracking + assert len(progress_values) > 0 + assert progress_values[-1][0] == 100.0 + assert "Download complete" in progress_values[-1][1] + assert test_file.exists() + assert test_file.read_bytes() == b"test data" + + def test_download_file_proxy_configs(self, mock_get, monkeypatch) -> None: + proxy_configs = [ + {"HTTP_PROXY": "http://proxy1:8080"}, + {"HTTPS_PROXY": "https://proxy2:8080"}, + {"HTTP_PROXY": "http://proxy3:8080", "NO_PROXY": "localhost"}, + ] + + mock_response = Mock() + mock_response.headers = {"content-length": "100"} + mock_response.iter_content.return_value = [b"test data"] + mock_get.return_value = mock_response + + for config in proxy_configs: + # Clear previous env vars + monkeypatch.delenv("HTTP_PROXY", raising=False) + monkeypatch.delenv("HTTPS_PROXY", raising=False) + monkeypatch.delenv("NO_PROXY", raising=False) + + # Set new config + for key, value in config.items(): + monkeypatch.setenv(key, value) + + download_file_from_url("http://test.com", "test.txt") + + expected_proxies = {} + if "HTTP_PROXY" in config: + expected_proxies["http"] = config["HTTP_PROXY"] + if "HTTPS_PROXY" in config: + expected_proxies["https"] = config["HTTPS_PROXY"] + if "NO_PROXY" in config: + expected_proxies["no_proxy"] = config["NO_PROXY"] + + mock_get.assert_called_with("http://test.com", stream=True, timeout=30, proxies=expected_proxies, verify=True) + mock_get.reset_mock() + + def test_download_file_filesystem_operations(self, mock_get, mock_response, tmp_path) -> None: + mock_get.return_value = mock_response + + # Test directory creation + nested_path = tmp_path / "deep" / "nested" / "path" + test_file = nested_path / "test.txt" + + assert download_file_from_url("http://test.com", str(test_file)) + assert test_file.exists() + assert test_file.read_bytes() == b"test data" + + # Test file overwrite + assert download_file_from_url("http://test.com", str(test_file)) + assert test_file.exists()