diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml
index 4e3deff..ac31e37 100644
--- a/.github/workflows/pylint.yml
+++ b/.github/workflows/pylint.yml
@@ -37,7 +37,7 @@ jobs:
- name: Install dependencies
# these extra packages are required by pylint to validate the python imports
run: |
- uv pip install 'pylint>=3.3.2' defusedxml requests pymavlink pillow numpy matplotlib pyserial setuptools pytest
+ uv pip install 'pylint>=3.3.2' defusedxml requests pymavlink pillow numpy matplotlib pyserial setuptools pytest GitPython
- name: Analysing the code with pylint
run: |
diff --git a/README.md b/README.md
index 3cafc74..4b54ebe 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,6 @@ SPDX-License-Identifier: GPL-3.0-or-later
| [![md-link-check](https://github.com/ArduPilot/MethodicConfigurator/actions/workflows/markdown-link-check.yml/badge.svg)](https://github.com/ArduPilot/MethodicConfigurator/actions/workflows/markdown-link-check.yml) | | [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?logo=discord&logoColor=white)](https://discord.com/channels/674039678562861068/1308233496535371856) | ![PyPI - Downloads](https://img.shields.io/pypi/dm/ardupilot-methodic-configurator?link=https%3A%2F%2Fpypi.org%2Fproject%2Fardupilot-methodic-configurator%2F)
| |
-
*ArduPilot Methodic Configurator* is a software, developed by ArduPilot developers, that semi-automates a
[clear, proven and safe configuration sequence](https://ardupilot.github.io/MethodicConfigurator/TUNING_GUIDE_ArduCopter) for ArduCopter drones.
We are working on extending it to [ArduPlane](https://ardupilot.github.io/MethodicConfigurator/TUNING_GUIDE_ArduPlane),
diff --git a/ardupilot_methodic_configurator/annotate_params.py b/ardupilot_methodic_configurator/annotate_params.py
index 8d2ac41..5bfb90e 100755
--- a/ardupilot_methodic_configurator/annotate_params.py
+++ b/ardupilot_methodic_configurator/annotate_params.py
@@ -392,7 +392,7 @@ def get_xml_data(base_url: str, directory: str, filename: str, vehicle_type: str
url = base_url + filename
proxies = get_env_proxies()
try:
- response = requests_get(url, timeout=5, proxies=proxies)
+ response = requests_get(url, timeout=5, proxies=proxies) if proxies else requests_get(url, timeout=5)
if response.status_code != 200:
logging.warning("Remote URL: %s", url)
msg = f"HTTP status code {response.status_code}"
diff --git a/ardupilot_methodic_configurator/backend_filesystem.py b/ardupilot_methodic_configurator/backend_filesystem.py
index 92e73f6..d2cb84c 100644
--- a/ardupilot_methodic_configurator/backend_filesystem.py
+++ b/ardupilot_methodic_configurator/backend_filesystem.py
@@ -9,6 +9,7 @@
"""
# from sys import exit as sys_exit
+import hashlib
from argparse import ArgumentParser
from logging import debug as logging_debug
from logging import error as logging_error
@@ -18,6 +19,7 @@
from os import listdir as os_listdir
from os import path as os_path
from os import rename as os_rename
+from pathlib import Path
from platform import system as platform_system
from re import compile as re_compile
from shutil import copy2 as shutil_copy2
@@ -25,14 +27,14 @@
from typing import Any, Optional
from zipfile import ZipFile
-from requests import get as requests_get
+from git import Repo
+from git.exc import InvalidGitRepositoryError
from ardupilot_methodic_configurator import _
from ardupilot_methodic_configurator.annotate_params import (
PARAM_DEFINITION_XML_FILE,
Par,
format_columns,
- get_env_proxies,
get_xml_dir,
get_xml_url,
load_default_param_file,
@@ -639,20 +641,25 @@ def get_upload_local_and_remote_filenames(self, selected_file: str) -> tuple[str
return "", ""
@staticmethod
- def download_file_from_url(url: str, local_filename: str, timeout: int = 5) -> bool:
- if not url or not local_filename:
- logging_error(_("URL or local filename not provided."))
- return False
- logging_info(_("Downloading %s from %s"), local_filename, url)
- response = requests_get(url, timeout=timeout, proxies=get_env_proxies())
-
- if response.status_code == 200:
- with open(local_filename, "wb") as file:
- file.write(response.content)
- return True
+ def get_git_commit_hash() -> str:
+ try:
+ repo = Repo(search_parent_directories=True)
+ return str(repo.head.object.hexsha[:7])
+ except InvalidGitRepositoryError:
+ # Fallback to reading the git_hash.txt file
+ git_hash_file = os_path.join(os_path.dirname(__file__), "git_hash.txt")
+ if os_path.exists(git_hash_file):
+ with open(git_hash_file, encoding="utf-8") as file:
+ return file.read().strip()
+ return ""
- logging_error(_("Failed to download the file"))
- return False
+ @staticmethod
+ def verify_file_hash(file_path: Path, expected_hash: str) -> bool:
+ sha256_hash = hashlib.sha256()
+ with open(file_path, "rb") as f:
+ for byte_block in iter(lambda: f.read(4096), b""):
+ sha256_hash.update(byte_block)
+ return sha256_hash.hexdigest() == expected_hash
@staticmethod
def add_argparse_arguments(parser: ArgumentParser) -> ArgumentParser:
diff --git a/ardupilot_methodic_configurator/backend_internet.py b/ardupilot_methodic_configurator/backend_internet.py
new file mode 100644
index 0000000..cc36572
--- /dev/null
+++ b/ardupilot_methodic_configurator/backend_internet.py
@@ -0,0 +1,195 @@
+"""
+Check for software updates and install them if available.
+
+This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator
+
+SPDX-FileCopyrightText: 2024-2025 Amilcar Lucas
+
+SPDX-License-Identifier: GPL-3.0-or-later
+"""
+
+import os
+import subprocess
+import tempfile
+from datetime import datetime, timezone
+from logging import error as logging_error
+from logging import info as logging_info
+from pathlib import Path
+from typing import Any, Callable, Optional
+from urllib.parse import urljoin
+
+from requests import HTTPError as requests_HTTPError
+from requests import RequestException as requests_RequestException
+from requests import Timeout as requests_Timeout
+from requests import get as requests_get
+from requests.exceptions import RequestException
+
+from ardupilot_methodic_configurator import _
+from ardupilot_methodic_configurator.backend_filesystem import LocalFilesystem
+
+# Constants
+GITHUB_API_URL_RELEASES = "https://api.github.com/repos/ArduPilot/MethodicConfigurator/releases/"
+
+
+def download_file_from_url(
+ url: str, local_filename: str, timeout: int = 30, progress_callback: Optional[Callable[[float, str], None]] = None
+) -> bool:
+ if not url or not local_filename:
+ logging_error(_("URL or local filename not provided."))
+ return False
+
+ logging_info(_("Downloading %s from %s"), local_filename, url)
+
+ try:
+ proxies_dict = {
+ "http": os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy"),
+ "https": os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy"),
+ "no_proxy": os.environ.get("NO_PROXY") or os.environ.get("no_proxy"),
+ }
+
+ # Remove None values
+ proxies = {k: v for k, v in proxies_dict.items() if v is not None}
+
+ # Make request with proxy support
+ response = requests_get(
+ url,
+ stream=True,
+ timeout=timeout,
+ proxies=proxies,
+ verify=True, # SSL verification
+ )
+ response.raise_for_status()
+
+ total_size = int(response.headers.get("content-length", 0))
+ block_size = 8192
+ downloaded = 0
+
+ os.makedirs(os.path.dirname(os.path.abspath(local_filename)), exist_ok=True)
+
+ with open(local_filename, "wb") as file:
+ for chunk in response.iter_content(chunk_size=block_size):
+ if chunk:
+ file.write(chunk)
+ downloaded += len(chunk)
+ if progress_callback and total_size:
+ progress = (downloaded / total_size) * 100
+ msg = _("Downloading ... {:.1f}%")
+ progress_callback(progress, msg.format(progress))
+
+ if progress_callback:
+ progress_callback(100.0, _("Download complete"))
+ return True
+
+ except requests_Timeout:
+ logging_error(_("Download timed out"))
+ except requests_RequestException as e:
+ logging_error(_("Network error during download: {}").format(e))
+ except OSError as e:
+ logging_error(_("File system error: {}").format(e))
+ except ValueError as e:
+ logging_error(_("Invalid data received from %s: %s"), url, e)
+
+ return False
+
+
+def get_release_info(name: str, should_be_pre_release: bool, timeout: int = 30) -> dict[str, Any]:
+ """
+ Get release information from GitHub API.
+
+ Args:
+ name: Release name/path (e.g. '/latest')
+ should_be_pre_release: Whether the release should be a pre-release
+ timeout: Request timeout in seconds
+
+ Returns:
+ Release information dictionary
+
+ Raises:
+ RequestException: If the request fails
+
+ """
+ if not name:
+ msg = "Release name cannot be empty"
+ raise ValueError(msg)
+
+ try:
+ url = urljoin(GITHUB_API_URL_RELEASES, name.lstrip("/"))
+ response = requests_get(url, timeout=timeout)
+ response.raise_for_status()
+
+ release_info = response.json()
+
+ if should_be_pre_release and not release_info["prerelease"]:
+ logging_error(_("The latest continuous delivery build must be a pre-release"))
+ if not should_be_pre_release and release_info["prerelease"]:
+ logging_error(_("The latest stable release must not be a pre-release"))
+
+ return release_info # type: ignore[no-any-return]
+
+ except requests_HTTPError as e:
+ if e.response.status_code == 403:
+ logging_error(_("Failed to fetch release info: {}").format(e))
+ # Get the rate limit reset time
+ reset_time = int(e.response.headers.get("X-RateLimit-Reset", 0))
+ # Create a timezone-aware UTC datetime
+ reset_datetime = datetime.fromtimestamp(reset_time, timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
+ logging_error(_("Rate limit exceeded. Please try again after: %s (UTC)"), reset_datetime)
+ raise
+ except RequestException as e:
+ logging_error(_("Failed to fetch release info: {}").format(e))
+ raise
+ except (KeyError, ValueError) as e:
+ logging_error(_("Invalid release data: {}").format(e))
+ raise
+
+
+def download_and_install_on_windows(
+ download_url: str,
+ file_name: str,
+ expected_hash: Optional[str] = None,
+ progress_callback: Optional[Callable[[float, str], None]] = None,
+) -> bool:
+ logging_info(_("Downloading and installing new version for Windows..."))
+ try:
+ with tempfile.TemporaryDirectory() as temp_dir:
+ temp_path = os.path.join(temp_dir, file_name)
+
+ # Download with progress updates
+ if not download_file_from_url(
+ download_url,
+ temp_path,
+ timeout=60, # Increased timeout for large files
+ progress_callback=progress_callback,
+ ):
+ return False
+
+ if expected_hash and not LocalFilesystem.verify_file_hash(Path(temp_path), expected_hash):
+ logging_error(_("File hash verification failed"))
+ return False
+
+ if progress_callback:
+ progress_callback(100.0, _("Starting installation..."))
+
+ # Run installer
+ result = subprocess.run( # noqa: S603
+ [temp_path],
+ shell=False,
+ check=True,
+ capture_output=True,
+ text=True,
+ creationflags=subprocess.CREATE_NO_WINDOW, # type: ignore[attr-defined]
+ )
+
+ return result.returncode == 0
+
+ except subprocess.SubprocessError as e:
+ logging_error(_("Installation failed: {}").format(e))
+ return False
+ except OSError as e:
+ logging_error(_("File operation failed: {}").format(e))
+ return False
+
+
+def download_and_install_pip_release() -> int:
+ logging_info(_("Updating via pip for Linux and MacOS..."))
+ return os.system("pip install --upgrade ardupilot_methodic_configurator") # noqa: S605, S607
diff --git a/ardupilot_methodic_configurator/frontend_tkinter_parameter_editor.py b/ardupilot_methodic_configurator/frontend_tkinter_parameter_editor.py
index a6eabdb..f4de1ad 100755
--- a/ardupilot_methodic_configurator/frontend_tkinter_parameter_editor.py
+++ b/ardupilot_methodic_configurator/frontend_tkinter_parameter_editor.py
@@ -30,6 +30,7 @@
from ardupilot_methodic_configurator.backend_filesystem import LocalFilesystem, is_within_tolerance
from ardupilot_methodic_configurator.backend_filesystem_program_settings import ProgramSettings
from ardupilot_methodic_configurator.backend_flightcontroller import FlightController
+from ardupilot_methodic_configurator.backend_internet import download_file_from_url
from ardupilot_methodic_configurator.common_arguments import add_common_arguments_and_parse
from ardupilot_methodic_configurator.frontend_tkinter_base import (
AutoResizeCombobox,
@@ -459,9 +460,9 @@ def __should_download_file_from_url(self, selected_file: str) -> None:
if self.local_filesystem.vehicle_configuration_file_exists(local_filename):
return # file already exists in the vehicle directory, no need to download it
msg = _("Should the {local_filename} file be downloaded from the URL\n{url}?")
- if messagebox.askyesno(
- _("Download file from URL"), msg.format(**locals())
- ) and not self.local_filesystem.download_file_from_url(url, local_filename):
+ if messagebox.askyesno(_("Download file from URL"), msg.format(**locals())) and not download_file_from_url(
+ url, local_filename
+ ):
error_msg = _("Failed to download {local_filename} from {url}, please download it manually")
messagebox.showerror(_("Download failed"), error_msg.format(**locals()))
diff --git a/ardupilot_methodic_configurator/frontend_tkinter_software_update.py b/ardupilot_methodic_configurator/frontend_tkinter_software_update.py
new file mode 100644
index 0000000..390b98f
--- /dev/null
+++ b/ardupilot_methodic_configurator/frontend_tkinter_software_update.py
@@ -0,0 +1,91 @@
+"""
+Check for software updates and install them if available.
+
+This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator
+
+SPDX-FileCopyrightText: 2024-2025 Amilcar Lucas
+
+SPDX-License-Identifier: GPL-3.0-or-later
+"""
+
+import tkinter as tk
+from tkinter import ttk
+from typing import Callable, Optional
+
+from ardupilot_methodic_configurator import _, __version__
+
+
+class UpdateDialog: # pylint: disable=too-many-instance-attributes
+ """Dialog for displaying software update information and handling user interaction."""
+
+ def __init__(self, version_info: str, download_callback: Optional[Callable[[], bool]] = None) -> None:
+ self.root = tk.Tk()
+ self.root.title(_("Amilcar Lucas's - ArduPilot methodic configurator ") + __version__ + _(" - Update Software"))
+ self.download_callback = download_callback
+ self.root.protocol("WM_DELETE_WINDOW", self.on_cancel)
+
+ self.frame = ttk.Frame(self.root, padding="20")
+ self.frame.grid(sticky="nsew")
+
+ self.msg = ttk.Label(self.frame, text=version_info, wraplength=650, justify="left")
+ self.msg.grid(row=0, column=0, columnspan=2, pady=20)
+
+ self.progress = ttk.Progressbar(self.frame, orient="horizontal", length=400, mode="determinate")
+ self.progress.grid(row=1, column=0, columnspan=2, pady=10, padx=10)
+ self.progress.grid_remove()
+
+ self.status_label = ttk.Label(self.frame, text="")
+ self.status_label.grid(row=2, column=0, columnspan=2)
+
+ self.result: Optional[bool] = None
+ self._setup_buttons()
+
+ def _setup_buttons(self) -> None:
+ self.yes_btn = ttk.Button(self.frame, text=_("Update Now"), command=self.on_yes)
+ self.no_btn = ttk.Button(self.frame, text=_("Not Now"), command=self.on_no)
+ self.yes_btn.grid(row=3, column=0, padx=5)
+ self.no_btn.grid(row=3, column=1, padx=5)
+
+ def update_progress(self, value: float, status: str = "") -> None:
+ """Update progress directly."""
+ self.progress["value"] = value
+ if status:
+ self.status_label["text"] = status
+ self.root.update()
+
+ def on_yes(self) -> None:
+ self.progress.grid()
+ self.status_label.grid()
+ self.yes_btn.config(state="disabled")
+ self.no_btn.config(state="disabled")
+
+ if self.download_callback:
+ success = self.download_callback()
+ if success:
+ self.status_label["text"] = _("Update complete! Please restart the application.")
+ self.result = True
+ else:
+ self.status_label["text"] = _("Update failed!")
+ self.yes_btn.config(state="normal")
+ self.no_btn.config(state="normal")
+ self.result = False
+ self.root.after(2000, self.root.destroy)
+
+ def on_no(self) -> None:
+ self.result = False
+ self.root.destroy()
+
+ def on_cancel(self) -> None:
+ self.result = False
+ self.root.destroy()
+
+ def show(self) -> bool:
+ """
+ Display the update dialog and return user's choice.
+
+ Returns:
+ bool: True if user chose to update, False otherwise
+
+ """
+ self.root.mainloop()
+ return bool(self.result)
diff --git a/ardupilot_methodic_configurator/middleware_software_updates.py b/ardupilot_methodic_configurator/middleware_software_updates.py
new file mode 100755
index 0000000..d5d85ba
--- /dev/null
+++ b/ardupilot_methodic_configurator/middleware_software_updates.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+
+"""
+Check for software updates and install them if available.
+
+This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator
+
+SPDX-FileCopyrightText: 2024-2025 Amilcar Lucas
+
+SPDX-License-Identifier: GPL-3.0-or-later
+"""
+
+import platform
+from logging import basicConfig as logging_basicConfig
+from logging import debug as logging_error
+from logging import getLevelName as logging_getLevelName
+from logging import info as logging_info
+from logging import warning as logging_warning
+from typing import Any, Optional
+
+from requests import RequestException as requests_RequestException
+
+from ardupilot_methodic_configurator import _
+from ardupilot_methodic_configurator import __version__ as current_version
+from ardupilot_methodic_configurator.backend_filesystem import LocalFilesystem
+from ardupilot_methodic_configurator.backend_internet import (
+ download_and_install_on_windows,
+ download_and_install_pip_release,
+ get_release_info,
+)
+from ardupilot_methodic_configurator.frontend_tkinter_software_update import UpdateDialog
+
+
+def format_version_info(_current_version: str, _latest_release: str, _changes: str) -> str:
+ return (
+ _("New version available!")
+ + "\n\n"
+ + _("Current version: {_current_version}")
+ + "\n"
+ + _("Latest version: {_latest_release}")
+ + "\n\n"
+ + _("Changes:\n{_changes}")
+ ).format(**locals())
+
+
+class UpdateManager: # pylint: disable=too-few-public-methods
+ """Manages the software update process including user interaction and installation."""
+
+ def __init__(self) -> None:
+ self.dialog: Optional[UpdateDialog] = None
+
+ def _perform_download(self, latest_release: dict[str, Any]) -> bool:
+ if platform.system() == "Windows":
+ try:
+ asset = latest_release["assets"][0]
+ return download_and_install_on_windows(
+ download_url=asset["browser_download_url"],
+ file_name=asset["name"],
+ progress_callback=self.dialog.update_progress if self.dialog else None,
+ )
+ except (KeyError, IndexError) as e:
+ logging_error(_("Error accessing release assets: %s"), e)
+ return False
+ return download_and_install_pip_release() == 0
+
+ def check_and_update(self, latest_release: dict[str, Any], current_version_str: str) -> bool:
+ try:
+ latest_version = latest_release["tag_name"].lstrip("v")
+ if current_version_str == latest_version:
+ logging_info(_("Already running latest version."))
+ return True
+
+ version_info = format_version_info(
+ current_version_str, latest_version, latest_release.get("body", _("No changes listed"))
+ )
+ self.dialog = UpdateDialog(version_info, download_callback=lambda: self._perform_download(latest_release))
+ return self.dialog.show()
+
+ except KeyError as ke:
+ logging_error(_("Key error during update process: %s"), ke)
+ return False
+ except requests_RequestException as req_ex:
+ logging_error(_("Network error during update process: %s"), req_ex)
+ return False
+ except ValueError as val_ex:
+ logging_error(_("Value error during update process: %s"), val_ex)
+ return False
+
+
+def check_for_software_updates() -> None:
+ """Main update orchestration function."""
+ git_hash = LocalFilesystem.get_git_commit_hash()
+
+ msg = _("Running version: {} (git hash: {})")
+ logging_info(msg.format(current_version, git_hash))
+
+ try:
+ latest_release = get_release_info("/latest", should_be_pre_release=False)
+ update_manager = UpdateManager()
+ update_manager.check_and_update(latest_release, current_version)
+ except (requests_RequestException, ValueError) as e:
+ msg = _("Update check failed: {}")
+ logging_error(msg.format(e))
+
+
+if __name__ == "__main__":
+ logging_basicConfig(level=logging_getLevelName("DEBUG"), format="%(asctime)s - %(levelname)s - %(message)s")
+ logging_warning(
+ _(
+ "This main is for testing and development only, usually the check_for_software_updates is"
+ " called from another script"
+ )
+ )
+ check_for_software_updates()
diff --git a/pyproject.toml b/pyproject.toml
index 901a52a..4e4f0bf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,6 +50,7 @@ dependencies = [
"pillow",
"setuptools",
"requests",
+ "GitPython",
]
dynamic = ["version"]
diff --git a/tests/test_annotate_params.py b/tests/test_annotate_params.py
index 1cf5ec1..7a1d2a3 100755
--- a/tests/test_annotate_params.py
+++ b/tests/test_annotate_params.py
@@ -803,9 +803,69 @@ def test_main_oserror(self, mock_file) -> None:
@patch("ardupilot_methodic_configurator.annotate_params.get_xml_url")
def test_get_xml_url_exception(self, mock_get_xml_url_) -> None:
mock_get_xml_url_.side_effect = ValueError("Mocked Value Error")
- with pytest.raises(ValueError, match="Vehicle type 'NonExistingVehicle' is not supported."):
+ with pytest.raises(ValueError, match="Vehicle type 'NonExistingVehicle' is not supported."): # noqa: PT012
get_xml_url("NonExistingVehicle", "4.0")
+ @patch("requests.get")
+ def test_get_xml_data_remote_file(mock_get) -> None:
+ """Test fetching XML data from remote file."""
+ # Mock the response
+ mock_get.return_value.status_code = 200
+ mock_get.return_value.text = ""
+
+ # Remove the test.xml file if it exists
+ with contextlib.suppress(FileNotFoundError):
+ os.remove("test.xml")
+
+ # Call the function with a remote file
+ result = get_xml_data("http://example.com/", ".", "test.xml", "ArduCopter")
+
+ # Check the result
+ assert isinstance(result, ET.Element)
+
+ # Assert that requests.get was called once with correct parameters including proxies
+ mock_get.assert_called_once_with("http://example.com/test.xml", timeout=5, proxies=None)
+
+ @patch("requests.get")
+ def test_get_xml_data_remote_file_with_proxies(mock_get) -> None:
+ """Test fetching XML data with proxy configuration."""
+ # Mock environment variables
+ with patch.dict(
+ os.environ,
+ {"HTTP_PROXY": "http://proxy:8080", "HTTPS_PROXY": "https://proxy:8080", "NO_PROXY": "localhost"},
+ ):
+ # Mock the response
+ mock_get.return_value.status_code = 200
+ mock_get.return_value.text = ""
+
+ # Call the function
+ result = get_xml_data("http://example.com/", ".", "test.xml", "ArduCopter")
+
+ # Check the result
+ assert isinstance(result, ET.Element)
+
+ # Assert that requests.get was called with proxy settings
+ expected_proxies = {"http": "http://proxy:8080", "https": "https://proxy:8080", "no_proxy": "localhost"}
+ mock_get.assert_called_once_with("http://example.com/test.xml", timeout=5, proxies=expected_proxies)
+
+ @patch("requests.get")
+ def test_get_xml_data_remote_file_no_proxies(mock_get) -> None:
+ """Test fetching XML data with no proxy configuration."""
+ # Clear environment variables
+ with patch.dict(os.environ, {}, clear=True):
+ # Mock the response
+ mock_get.return_value.status_code = 200
+ mock_get.return_value.text = ""
+
+ # Call the function
+ result = get_xml_data("http://example.com/", ".", "test.xml", "ArduCopter")
+
+ # Check the result
+ assert isinstance(result, ET.Element)
+
+ # Assert that requests.get was called with no proxies
+ mock_get.assert_called_once_with("http://example.com/test.xml", timeout=5, proxies=None)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_backend_filesystem.py b/tests/test_backend_filesystem.py
index cf3fb99..445d345 100755
--- a/tests/test_backend_filesystem.py
+++ b/tests/test_backend_filesystem.py
@@ -10,11 +10,14 @@
SPDX-License-Identifier: GPL-3.0-or-later
"""
+import hashlib
import unittest
+from argparse import ArgumentParser
from os import path as os_path
+from pathlib import Path
from unittest.mock import MagicMock, patch
-from requests.exceptions import ConnectTimeout
+from git.exc import InvalidGitRepositoryError
from ardupilot_methodic_configurator.backend_filesystem import LocalFilesystem
@@ -222,99 +225,89 @@ def test_add_configuration_file_to_zip(self) -> None:
mock_join.assert_called_once_with("vehicle_dir", "test_file.param")
mock_zipfile.write.assert_called_once_with("vehicle_dir/test_file.param", arcname="test_file.param")
- @patch("ardupilot_methodic_configurator.backend_filesystem.requests_get")
- def test_download_file_from_url(self, mock_get) -> None:
- """Test file download functionality with various scenarios."""
- # Setup mock response for successful download
- mock_response = MagicMock()
- mock_response.status_code = 200
- mock_response.content = b"test content"
- mock_get.return_value = mock_response
- mock_get.side_effect = None # Clear any previous side effects
-
- # Test successful download
- mock_open_obj = MagicMock()
- mock_file_obj = MagicMock()
- mock_open_obj.return_value.__enter__.return_value = mock_file_obj
-
- with patch("builtins.open", mock_open_obj):
- result = LocalFilesystem.download_file_from_url("http://test.com/file", "local_file.txt")
- assert result
- mock_get.assert_called_once_with("http://test.com/file", timeout=5)
- mock_open_obj.assert_called_once_with("local_file.txt", "wb")
- mock_file_obj.write.assert_called_once_with(b"test content")
-
- # Test failed download (404)
- mock_get.reset_mock()
- mock_get.side_effect = None # Reset side effect
- mock_response.status_code = 404
- result = LocalFilesystem.download_file_from_url("http://test.com/not_found", "local_file.txt")
- assert not result
-
- # Test with empty URL
- mock_get.reset_mock()
- result = LocalFilesystem.download_file_from_url("", "local_file.txt")
- assert not result
- mock_get.assert_not_called()
-
- # Test with empty local filename
- mock_get.reset_mock()
- result = LocalFilesystem.download_file_from_url("http://test.com/file", "")
- assert not result
- mock_get.assert_not_called()
+ def test_get_upload_local_and_remote_filenames_missing_file(self) -> None:
+ """Test get_upload_local_and_remote_filenames with missing file."""
+ lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False)
+ result = lfs.get_upload_local_and_remote_filenames("missing_file")
+ assert result == ("", "")
- # Test download with connection timeout
- mock_get.reset_mock()
+ def test_get_upload_local_and_remote_filenames_missing_upload_section(self) -> None:
+ """Test get_upload_local_and_remote_filenames with missing upload section."""
+ lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False)
+ lfs.configuration_steps = {"selected_file": {}}
+ result = lfs.get_upload_local_and_remote_filenames("selected_file")
+ assert result == ("", "")
- mock_get.side_effect = ConnectTimeout()
- # this should be fixed at some point
- # result = LocalFilesystem.download_file_from_url("http://test.com/timeout", "local_file.txt")
- # assert not result
- # mock_get.assert_called_once_with("http://test.com/timeout", timeout=5)
+ def test_get_download_url_and_local_filename_missing_file(self) -> None:
+ """Test get_download_url_and_local_filename with missing file."""
+ lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False)
+ result = lfs.get_download_url_and_local_filename("missing_file")
+ assert result == ("", "")
- def test_get_download_url_and_local_filename(self) -> None:
+ def test_get_download_url_and_local_filename_missing_download_section(self) -> None:
+ """Test get_download_url_and_local_filename with missing download section."""
lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False)
- with patch("os.path.join") as mock_join:
- mock_join.return_value = "vehicle_dir/dest_local"
- lfs.configuration_steps = {
- "selected_file": {"download_file": {"source_url": "http://example.com/file", "dest_local": "dest_local"}}
- }
- result = lfs.get_download_url_and_local_filename("selected_file")
- assert result == ("http://example.com/file", "vehicle_dir/dest_local")
- mock_join.assert_called_once_with("vehicle_dir", "dest_local")
+ lfs.configuration_steps = {"selected_file": {}}
+ result = lfs.get_download_url_and_local_filename("selected_file")
+ assert result == ("", "")
- def test_get_upload_local_and_remote_filenames(self) -> None:
+ def test_write_and_read_last_uploaded_filename_error_handling(self) -> None:
+ """Test error handling in write_and_read_last_uploaded_filename."""
lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False)
- with patch("os.path.join") as mock_join:
- mock_join.return_value = "vehicle_dir/source_local"
- lfs.configuration_steps = {
- "selected_file": {"upload_file": {"source_local": "source_local", "dest_on_fc": "dest_on_fc"}}
- }
- result = lfs.get_upload_local_and_remote_filenames("selected_file")
- assert result == ("vehicle_dir/source_local", "dest_on_fc")
- mock_join.assert_called_once_with("vehicle_dir", "source_local")
+ test_filename = "test.param"
+
+ # Test write error
+ with patch("builtins.open", side_effect=OSError("Write error")):
+ lfs.write_last_uploaded_filename(test_filename)
+ # Should not raise exception, just log error
+
+ # Test read error
+ with patch("builtins.open", side_effect=OSError("Read error")):
+ result = lfs._LocalFilesystem__read_last_uploaded_filename() # pylint: disable=protected-access
+ assert result == ""
- def test_copy_fc_values_to_file(self) -> None:
+ def test_copy_fc_values_to_file_with_missing_params(self) -> None:
+ """Test copy_fc_values_to_file with missing parameters."""
lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False)
- test_params = {"PARAM1": 1.0, "PARAM2": 2.0}
+ test_params = {"PARAM1": 1.0, "PARAM2": 2.0, "PARAM3": 3.0}
test_file = "test.param"
- # Test with non-existent file
- result = lfs.copy_fc_values_to_file(test_file, test_params)
- assert result == 0
-
- # Test with existing file and matching parameters
+ # Test with partially matching parameters
lfs.file_parameters = {test_file: {"PARAM1": MagicMock(value=0.0), "PARAM2": MagicMock(value=0.0)}}
result = lfs.copy_fc_values_to_file(test_file, test_params)
assert result == 2
assert lfs.file_parameters[test_file]["PARAM1"].value == 1.0
assert lfs.file_parameters[test_file]["PARAM2"].value == 2.0
- def test_write_and_read_last_uploaded_filename(self) -> None:
+ def test_get_start_file_empty_files(self) -> None:
+ """Test get_start_file with empty files list."""
lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False)
- test_filename = "test.param"
+ lfs.file_parameters = {}
+ result = lfs.get_start_file(1, True) # noqa: FBT003
+ assert result == ""
+
+ def test_get_eval_variables_with_none(self) -> None:
+ """Test get_eval_variables with None values."""
+ lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False)
+ lfs.vehicle_components = None
+ lfs.doc_dict = None
+ result = lfs.get_eval_variables()
+ assert not result
+
+ def test_tolerance_check_with_zero_values(self) -> None:
+ """Test numerical value comparison with zero values."""
+ lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False)
+
+ # Test zero values
+ x, y = 0.0, 0.0
+ assert abs(x - y) <= 1e-08 + (1e-03 * abs(y))
+
+ # Test small positive vs zero
+ x, y = 1e-10, 0.0
+ assert abs(x - y) <= 1e-08 + (1e-03 * abs(y))
# Test writing
+ test_filename = "test_param.param"
expected_path = os_path.join("vehicle_dir", "last_uploaded_filename.txt")
with patch("builtins.open", unittest.mock.mock_open()) as mock_file:
lfs.write_last_uploaded_filename(test_filename)
@@ -327,6 +320,29 @@ def test_write_and_read_last_uploaded_filename(self) -> None:
assert result == test_filename
mock_file.assert_called_once_with(expected_path, encoding="utf-8")
+ def test_tolerance_handling(self) -> None:
+ """Test parameter value tolerance checking."""
+ # Setup LocalFilesystem instance
+ from ardupilot_methodic_configurator.backend_filesystem import ( # pylint: disable=import-outside-toplevel
+ is_within_tolerance,
+ )
+
+ # Test cases within tolerance (default 0.1%)
+ assert is_within_tolerance(10.0, 10.009) # +0.09% - should pass
+ assert is_within_tolerance(10.0, 9.991) # -0.09% - should pass
+ assert is_within_tolerance(100, 100) # Exact match
+ assert is_within_tolerance(0.0, 0.0) # Zero case
+
+ # Test cases outside tolerance
+ assert not is_within_tolerance(10.0, 10.02) # +0.2% - should fail
+ assert not is_within_tolerance(10.0, 9.98) # -0.2% - should fail
+ assert not is_within_tolerance(100, 101) # Integer case
+
+ # Test with custom tolerance
+ custom_tolerance = 0.2 # 0.2%
+ assert is_within_tolerance(10.0, 10.015, atol=custom_tolerance) # +0.15% - should pass
+ assert is_within_tolerance(10.0, 9.985, atol=custom_tolerance) # -0.15% - should pass
+
def test_write_param_default_values(self) -> None:
lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False)
@@ -503,6 +519,127 @@ def test_all_intermediate_parameter_file_comments(self) -> None:
result = lfs._LocalFilesystem__all_intermediate_parameter_file_comments() # pylint: disable=protected-access
assert result == {"PARAM1": "Comment 1", "PARAM2": "Override comment 2", "PARAM3": "Comment 3"}
+ def test_get_git_commit_hash(self) -> None:
+ # Test with valid git repo
+ with patch("ardupilot_methodic_configurator.backend_filesystem.Repo") as mock_repo:
+ mock_repo.return_value.head.object.hexsha = "abcdef1234567890"
+ result = LocalFilesystem.get_git_commit_hash()
+ assert result == "abcdef1"
+
+ # Test with no git repo but git_hash.txt exists
+ with (
+ patch("ardupilot_methodic_configurator.backend_filesystem.Repo") as mock_repo,
+ patch("ardupilot_methodic_configurator.backend_filesystem.os_path.exists") as mock_exists,
+ patch("builtins.open", unittest.mock.mock_open(read_data="xyz1234")),
+ ):
+ mock_repo.side_effect = InvalidGitRepositoryError()
+ mock_exists.return_value = True
+ result = LocalFilesystem.get_git_commit_hash()
+ assert result == "xyz1234"
+
+ # Test with no git repo and no git_hash.txt
+ with (
+ patch("ardupilot_methodic_configurator.backend_filesystem.Repo") as mock_repo,
+ patch("ardupilot_methodic_configurator.backend_filesystem.os_path.exists") as mock_exists,
+ ):
+ mock_repo.side_effect = InvalidGitRepositoryError()
+ mock_exists.return_value = False
+ result = LocalFilesystem.get_git_commit_hash()
+ assert result == ""
+
+ def test_verify_file_hash(self) -> None:
+ test_content = b"test content"
+ test_hash = hashlib.sha256(test_content).hexdigest()
+ mock_file = unittest.mock.mock_open(read_data=test_content)
+
+ with patch("builtins.open", mock_file):
+ result = LocalFilesystem.verify_file_hash(Path("test_file"), test_hash)
+ assert result is True
+
+ # Test with incorrect hash
+ result = LocalFilesystem.verify_file_hash(Path("test_file"), "wrong_hash")
+ assert result is False
+
+ def test_extend_and_reformat_parameter_documentation_metadata(self) -> None:
+ lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False)
+
+ test_doc_dict = {
+ "PARAM1": {
+ "humanName": "Test Param",
+ "documentation": ["Test documentation"],
+ "fields": {
+ "Units": "m/s (meters per second)",
+ "Range": "0 100",
+ "Calibration": "true",
+ "ReadOnly": "yes",
+ "RebootRequired": "1",
+ "Bitmask": "0:Option1, 1:Option2",
+ },
+ "values": {"1": "Value1", "2": "Value2"},
+ }
+ }
+
+ lfs.doc_dict = test_doc_dict
+ lfs.param_default_dict = {"PARAM1": MagicMock(value=50.0)}
+
+ lfs._LocalFilesystem__extend_and_reformat_parameter_documentation_metadata() # pylint: disable=protected-access
+
+ result = lfs.doc_dict["PARAM1"]
+ assert result["unit"] == "m/s"
+ assert result["unit_tooltip"] == "meters per second"
+ assert result["min"] == 0.0
+ assert result["max"] == 100.0
+ assert result["Calibration"] is True
+ assert result["ReadOnly"] is True
+ assert result["RebootRequired"] is True
+ assert result["Bitmask"] == {0: "Option1", 1: "Option2"}
+ assert result["Values"] == {1: "Value1", 2: "Value2"}
+ assert "Default: 50" in result["doc_tooltip"]
+
+ def test_add_argparse_arguments(self) -> None:
+ parser = ArgumentParser()
+ LocalFilesystem.add_argparse_arguments(parser)
+
+ # Verify all expected arguments are added
+ args = vars(parser.parse_args([]))
+ assert "vehicle_type" in args
+ assert "vehicle_dir" in args
+ assert "n" in args
+ assert "allow_editing_template_files" in args
+
+ # Test default values
+ assert args["vehicle_type"] == ""
+ assert args["n"] == -1
+ assert args["allow_editing_template_files"] is False
+
+ # Test with parameters
+ args = vars(parser.parse_args(["-t", "ArduCopter", "--n", "1", "--allow-editing-template-files"]))
+ assert args["vehicle_type"] == "ArduCopter"
+ assert args["n"] == 1
+ assert args["allow_editing_template_files"] is True
+
+ def test_annotate_intermediate_comments_to_param_dict(self) -> None:
+ lfs = LocalFilesystem("vehicle_dir", "vehicle_type", None, allow_editing_template_files=False)
+
+ # Set up mock comments in file_parameters
+ mock_param1 = MagicMock()
+ mock_param1.comment = "Comment 1"
+ mock_param2 = MagicMock()
+ mock_param2.comment = "Comment 2"
+
+ lfs.file_parameters = {"file1.param": {"PARAM1": mock_param1}, "file2.param": {"PARAM2": mock_param2}}
+
+ input_dict = {"PARAM1": 1.0, "PARAM2": 2.0, "PARAM3": 3.0}
+ result = lfs.annotate_intermediate_comments_to_param_dict(input_dict)
+
+ assert len(result) == 3
+ assert result["PARAM1"].value == 1.0
+ assert result["PARAM1"].comment == "Comment 1"
+ assert result["PARAM2"].value == 2.0
+ assert result["PARAM2"].comment == "Comment 2"
+ assert result["PARAM3"].value == 3.0
+ assert result["PARAM3"].comment == ""
+
class TestCopyTemplateFilesToNewVehicleDir(unittest.TestCase):
"""Copy Template Files To New Vehicle Directory testclass."""
diff --git a/tests/test_backend_internet.py b/tests/test_backend_internet.py
new file mode 100755
index 0000000..ec0d896
--- /dev/null
+++ b/tests/test_backend_internet.py
@@ -0,0 +1,292 @@
+#!/usr/bin/env python3
+
+"""
+Tests for the backend_internet.py file.
+
+This file is part of Ardupilot methodic configurator. https://github.com/ArduPilot/MethodicConfigurator
+
+SPDX-FileCopyrightText: 2024-2025 Amilcar do Carmo Lucas
+
+SPDX-License-Identifier: GPL-3.0-or-later
+"""
+
+import os
+from unittest.mock import Mock, patch
+
+import pytest
+from requests import HTTPError as requests_HTTPError
+from requests import RequestException as requests_RequestException
+
+from ardupilot_methodic_configurator.backend_internet import (
+ download_and_install_on_windows,
+ download_and_install_pip_release,
+ download_file_from_url,
+ get_release_info,
+)
+
+
+def test_download_file_from_url_empty_params() -> None:
+ assert not download_file_from_url("", "")
+ assert not download_file_from_url("http://test.com", "")
+ assert not download_file_from_url("", "test.txt")
+
+
+@pytest.mark.parametrize(
+ "env_vars",
+ [
+ {},
+ {"HTTP_PROXY": "http://proxy:8080"},
+ {"HTTPS_PROXY": "https://proxy:8080"},
+ {"NO_PROXY": "localhost"},
+ ],
+)
+def test_download_file_from_url_proxy_handling(env_vars) -> None:
+ with patch.dict(os.environ, env_vars, clear=True), patch("requests.get") as mock_get:
+ mock_get.return_value.status_code = 404
+ assert not download_file_from_url("http://test.com", "test.txt")
+
+
+@patch("ardupilot_methodic_configurator.backend_internet.requests_get")
+def test_download_file_success(mock_get, tmp_path) -> None:
+ # Setup mock response
+ mock_response = Mock()
+ mock_response.headers = {"content-length": "100"}
+ mock_response.iter_content.return_value = [b"test data"]
+ mock_response.raise_for_status.return_value = None
+ mock_get.return_value = mock_response
+
+ test_file = tmp_path / "test.txt"
+ assert download_file_from_url("http://test.com", str(test_file))
+ assert test_file.read_bytes() == b"test data"
+
+
+@pytest.fixture
+def mock_get_() -> Mock:
+ with patch("ardupilot_methodic_configurator.backend_internet.requests_get") as _mock:
+ yield _mock
+
+
+@patch("ardupilot_methodic_configurator.backend_internet.requests_get")
+def test_download_file_invalid_content_length(mock_get) -> None:
+ # Test handling of invalid content-length header
+ mock_response = Mock()
+ mock_response.headers = {"content-length": "invalid"}
+ mock_response.iter_content.return_value = [b"test data"]
+ mock_get.return_value = mock_response
+ assert not download_file_from_url("http://test.com", "test.txt")
+
+
+@patch("ardupilot_methodic_configurator.backend_internet.requests_get")
+def test_download_file_missing_content_length(mock_get) -> None:
+ # Test handling of missing content-length header
+ mock_response = Mock()
+ mock_response.headers = {}
+ mock_response.iter_content.return_value = [b"test data"]
+ mock_get.return_value = mock_response
+ assert download_file_from_url("http://test.com", "test.txt")
+
+
+@patch("ardupilot_methodic_configurator.backend_internet.requests_get")
+def test_download_file_empty_response(mock_get) -> None:
+ # Test handling of empty response
+ mock_response = Mock()
+ mock_response.headers = {"content-length": "0"}
+ mock_response.iter_content.return_value = []
+ mock_get.return_value = mock_response
+ assert download_file_from_url("http://test.com", "test.txt")
+
+
+@patch("ardupilot_methodic_configurator.backend_internet.requests_get")
+def test_download_file_http_error(mock_get) -> None:
+ # Test HTTP error handling
+ mock_response = Mock()
+ mock_response.raise_for_status.side_effect = requests_HTTPError("404 Not Found")
+ mock_get.return_value = mock_response
+ assert not download_file_from_url("http://test.com", "test.txt")
+
+
+@patch("ardupilot_methodic_configurator.backend_internet.requests_get")
+def test_download_file_with_progress_no_content_length(mock_get, tmp_path) -> None:
+ # Test progress callback without content-length header
+ mock_response = Mock()
+ mock_response.headers = {}
+ mock_response.iter_content.return_value = [b"data"] * 4
+ mock_response.raise_for_status.return_value = None
+ mock_get.return_value = mock_response
+
+ progress_callback = Mock()
+ test_file = tmp_path / "test.txt"
+
+ assert download_file_from_url("http://test.com", str(test_file), progress_callback=progress_callback)
+ assert progress_callback.call_count == 1 # Only final callback
+ progress_callback.assert_called_with(100.0, "Download complete")
+
+
+@patch("ardupilot_methodic_configurator.backend_internet.requests_get")
+def test_download_file_proxy_configuration(mock_get, monkeypatch) -> None:
+ # Test proxy configuration handling
+ mock_response = Mock()
+ mock_response.headers = {"content-length": "100"}
+ mock_response.iter_content.return_value = [b"test data"]
+ mock_get.return_value = mock_response
+
+ # Set environment variables
+ monkeypatch.setenv("HTTP_PROXY", "http://proxy:8080")
+ monkeypatch.setenv("HTTPS_PROXY", "https://proxy:8080")
+ monkeypatch.setenv("NO_PROXY", "localhost")
+
+ assert download_file_from_url("http://test.com", "test.txt")
+ mock_get.assert_called_once_with(
+ "http://test.com",
+ stream=True,
+ timeout=30,
+ proxies={"http": "http://proxy:8080", "https": "https://proxy:8080", "no_proxy": "localhost"},
+ verify=True,
+ )
+
+
+@patch("ardupilot_methodic_configurator.backend_internet.requests_get")
+def test_download_file_value_error(mock_get) -> None:
+ # Test handling of ValueError during download
+ mock_response = Mock()
+ mock_response.headers = {"content-length": "100"}
+ mock_response.iter_content.side_effect = ValueError("Invalid data")
+ mock_get.return_value = mock_response
+ assert not download_file_from_url("http://test.com", "test.txt")
+
+
+def test_download_file_from_url_invalid_url() -> None:
+ # Test with invalid URL format
+ assert not download_file_from_url("not_a_valid_url", "test.txt")
+
+
+@patch("ardupilot_methodic_configurator.backend_internet.requests_get")
+def test_download_file_unicode_error(mock_get) -> None:
+ # Test handling of Unicode decode errors
+ mock_response = Mock()
+ mock_response.headers = {"content-length": "100"}
+ mock_response.iter_content.return_value = [bytes([0xFF, 0xFE, 0xFD])] # Invalid UTF-8
+ mock_get.return_value = mock_response
+ assert download_file_from_url("http://test.com", "test.txt")
+
+
+@patch("ardupilot_methodic_configurator.backend_internet.requests_get")
+def test_get_release_info_invalid_release(mock_get) -> None:
+ mock_get.side_effect = requests_RequestException()
+ with pytest.raises(requests_RequestException):
+ get_release_info("latest", False) # noqa: FBT003
+
+
+@patch("ardupilot_methodic_configurator.backend_internet.requests_get")
+def test_get_release_info_prerelease_mismatch(mock_get) -> None:
+ mock_response = Mock()
+ mock_response.json.return_value = {"prerelease": True}
+ mock_get.return_value = mock_response
+
+
+@patch("ardupilot_methodic_configurator.backend_internet.download_file_from_url")
+def test_download_and_install_windows_download_failure(mock_download) -> None:
+ mock_download.return_value = False
+ assert not download_and_install_on_windows("http://test.com", "test.exe")
+
+
+@patch("os.system")
+def test_download_and_install_pip_release(mock_system) -> None:
+ mock_system.return_value = 0
+ assert download_and_install_pip_release() == 0
+
+
+class TestDownloadFile:
+ @pytest.fixture
+ def mock_response(self) -> Mock:
+ response = Mock()
+ response.headers = {"content-length": "100"}
+ response.iter_content.return_value = [b"test data"]
+ response.raise_for_status.return_value = None
+ return response
+
+ @pytest.fixture
+ def mock_get(self) -> Mock:
+ with patch("ardupilot_methodic_configurator.backend_internet.requests_get") as _mock:
+ yield _mock
+
+ def test_download_file_network_errors(self, mock_get, caplog) -> None:
+ errors = [
+ requests_HTTPError("404 Not Found"),
+ requests_RequestException("Connection failed"),
+ ValueError("Invalid response"),
+ OSError("File system error"),
+ ]
+
+ for error in errors:
+ mock_get.side_effect = error
+ assert not download_file_from_url("http://test.com", "test.txt")
+ assert str(error) in caplog.text
+ caplog.clear()
+
+ def test_download_file_progress_tracking(self, mock_get, mock_response, tmp_path) -> None:
+ mock_get.return_value = mock_response
+ progress_values = []
+
+ def progress_callback(progress: float, msg: str) -> None:
+ progress_values.append((progress, msg))
+
+ test_file = tmp_path / "test.txt"
+ assert download_file_from_url("http://test.com", str(test_file), progress_callback=progress_callback)
+
+ # Verify progress tracking
+ assert len(progress_values) > 0
+ assert progress_values[-1][0] == 100.0
+ assert "Download complete" in progress_values[-1][1]
+ assert test_file.exists()
+ assert test_file.read_bytes() == b"test data"
+
+ def test_download_file_proxy_configs(self, mock_get, monkeypatch) -> None:
+ proxy_configs = [
+ {"HTTP_PROXY": "http://proxy1:8080"},
+ {"HTTPS_PROXY": "https://proxy2:8080"},
+ {"HTTP_PROXY": "http://proxy3:8080", "NO_PROXY": "localhost"},
+ ]
+
+ mock_response = Mock()
+ mock_response.headers = {"content-length": "100"}
+ mock_response.iter_content.return_value = [b"test data"]
+ mock_get.return_value = mock_response
+
+ for config in proxy_configs:
+ # Clear previous env vars
+ monkeypatch.delenv("HTTP_PROXY", raising=False)
+ monkeypatch.delenv("HTTPS_PROXY", raising=False)
+ monkeypatch.delenv("NO_PROXY", raising=False)
+
+ # Set new config
+ for key, value in config.items():
+ monkeypatch.setenv(key, value)
+
+ download_file_from_url("http://test.com", "test.txt")
+
+ expected_proxies = {}
+ if "HTTP_PROXY" in config:
+ expected_proxies["http"] = config["HTTP_PROXY"]
+ if "HTTPS_PROXY" in config:
+ expected_proxies["https"] = config["HTTPS_PROXY"]
+ if "NO_PROXY" in config:
+ expected_proxies["no_proxy"] = config["NO_PROXY"]
+
+ mock_get.assert_called_with("http://test.com", stream=True, timeout=30, proxies=expected_proxies, verify=True)
+ mock_get.reset_mock()
+
+ def test_download_file_filesystem_operations(self, mock_get, mock_response, tmp_path) -> None:
+ mock_get.return_value = mock_response
+
+ # Test directory creation
+ nested_path = tmp_path / "deep" / "nested" / "path"
+ test_file = nested_path / "test.txt"
+
+ assert download_file_from_url("http://test.com", str(test_file))
+ assert test_file.exists()
+ assert test_file.read_bytes() == b"test data"
+
+ # Test file overwrite
+ assert download_file_from_url("http://test.com", str(test_file))
+ assert test_file.exists()