Skip to content

Commit

Permalink
Add type hints
Browse files Browse the repository at this point in the history
Signed-off-by: Benjamin Drung <[email protected]>
  • Loading branch information
bdrung committed Mar 1, 2024
1 parent b458644 commit 9ba23dc
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 77 deletions.
51 changes: 27 additions & 24 deletions bdebstrap
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import shutil
import subprocess
import sys
import time
import typing

import ruamel.yaml

Expand Down Expand Up @@ -70,20 +71,20 @@ class Config(dict):
_ENV_PREFIX = "BDEBSTRAP_"
_KEYS = {"env", "mmdebstrap", "name"}

def __init__(self, *args, **kwargs):
def __init__(self, *args: typing.Any, **kwargs: dict[str, typing.Any]) -> None:
super().__init__(self, *args, **kwargs)
self.logger = logging.getLogger(__script_name__)
self.yaml = ruamel.yaml.YAML()
self.yaml.explicit_start = True
self.yaml.indent(offset=2, sequence=4)

def _set_mmdebstrap_option(self, option, value):
def _set_mmdebstrap_option(self, option: str, value: str | int | list[str]) -> None:
"""Set the given mmdebstrap option (overwriting existing values)."""
if "mmdebstrap" not in self:
self["mmdebstrap"] = {}
self["mmdebstrap"][option] = value

def _append_mmdebstrap_option(self, option, value):
def _append_mmdebstrap_option(self, option: str, value: str | int | list[str]) -> None:
"""Append the given mmdebstrap option to the list of values."""
if "mmdebstrap" not in self:
self["mmdebstrap"] = {}
Expand All @@ -92,7 +93,8 @@ class Config(dict):
else:
self["mmdebstrap"][option] = value

def add_command_line_arguments(self, args): # pylint: disable=too-many-branches
# pylint: disable-next=too-many-branches
def add_command_line_arguments(self, args: argparse.Namespace) -> None:
"""Add/Override configs from the given command line arguments."""
for config_filename in args.config:
self.load(config_filename)
Expand Down Expand Up @@ -148,7 +150,7 @@ class Config(dict):
if args.mirrors:
self._append_mmdebstrap_option("mirrors", args.mirrors)

def env_items(self):
def env_items(self) -> list[tuple[str, str]]:
"""Return key-value pair of environment variables."""
return sorted(
list(self.get("env", {}).items())
Expand All @@ -159,7 +161,7 @@ class Config(dict):
]
)

def check(self):
def check(self) -> None:
"""Check the format of the configuration."""
unknown_top_level_keys = sorted(k for k in self.keys() if k not in self._KEYS)
if unknown_top_level_keys:
Expand All @@ -178,6 +180,7 @@ class Config(dict):
f"Excepted: {MMDEBSTRAP_OPTS[key]}."
)
if MMDEBSTRAP_OPTS[key] is list:
assert isinstance(value, list)
# Check if list elements are strings
for element in value:
if not isinstance(element, str):
Expand All @@ -191,7 +194,7 @@ class Config(dict):
if "name" not in self:
raise ValueError("The configuration does not contain a 'name' entry.")

def load(self, config_filename):
def load(self, config_filename: str) -> None:
"""Loading configuration from given config file."""
self.logger.info("Loading configuration from '%s'...", config_filename)
try:
Expand Down Expand Up @@ -232,7 +235,7 @@ class Config(dict):

self["mmdebstrap"]["packages"] = list(packages.values())

def save(self, config_filename, simulate=False):
def save(self, config_filename: str, simulate: bool = False) -> None:
"""Save configuration to given config filename."""
self.logger.info(
"%s configuration to '%s'.",
Expand All @@ -247,11 +250,11 @@ class Config(dict):
clamp_mtime(config_filename, self.source_date_epoch)

@property
def source_date_epoch(self):
def source_date_epoch(self) -> typing.Any:
"""Return SOURCE_DATE_EPOCH (for reproducible builds)."""
return self.get("env", {}).get("SOURCE_DATE_EPOCH")

def set_source_date_epoch(self):
def set_source_date_epoch(self) -> None:
"""Set SOURCE_DATE_EPOCH (for reproducible builds) if not already set."""

if "env" not in self:
Expand All @@ -263,7 +266,7 @@ class Config(dict):
class Mmdebstrap:
"""Wrapper around calling mmdebstrap."""

def __init__(self, config):
def __init__(self, config: Config) -> None:
self.config = config
self.logger = logging.getLogger(__script_name__)

Expand All @@ -277,7 +280,7 @@ class Mmdebstrap:
return ["-v"]
return []

def construct_parameters(self, output_dir, simulate=False):
def construct_parameters(self, output_dir: str, simulate: bool = False) -> list[str]:
"""Construct the parameter for mmdebstrap from a given dictionary."""
# pylint: disable=too-many-branches
cmd = ["mmdebstrap"] + self._get_mmdebstrap_log_level_parameters()
Expand Down Expand Up @@ -339,14 +342,14 @@ class Mmdebstrap:

return cmd

def call(self, output_dir, simulate=False):
def call(self, output_dir: str, simulate: bool = False) -> None:
"""Call mmdebstrap."""
cmd = self.construct_parameters(output_dir, simulate)
self.logger.info("Calling %s", escape_cmd(cmd))
subprocess.check_call(cmd)
self.clamp_mtime(output_dir)

def clamp_mtime(self, output_dir):
def clamp_mtime(self, output_dir: str) -> None:
"""Clamp the modification time of the manifest, target, and output directory."""
for path in (
os.path.join(output_dir, "manifest"),
Expand All @@ -362,7 +365,7 @@ class Mmdebstrap:
)


def clamp_mtime(path, source_date_epoch):
def clamp_mtime(path: str, source_date_epoch: int | str | None) -> None:
"""Clamp the modification time for the given path to SOURCE_DATE_EPOCH."""
if not source_date_epoch:
return
Expand All @@ -371,7 +374,7 @@ def clamp_mtime(path, source_date_epoch):
os.utime(path, (int(source_date_epoch), int(source_date_epoch)))


def duration_str(duration):
def duration_str(duration: float) -> str:
"""Return duration in the biggest useful time unit (hours, minutes, seconds)."""
if duration < 60:
return f"{duration:.3f} seconds"
Expand All @@ -382,11 +385,11 @@ def duration_str(duration):
return f"{minutes // 60} h {minutes % 60} min {duration % 60:.3f} s (= {duration:.3f} s)"


def escape_cmd(cmd):
def escape_cmd(cmd: list[str]) -> str:
"""Escape command line arguments for printing/logging."""
unsafe_re = re.compile(r"[^\w@%+=:,./-]", re.ASCII)

def quote(cmd_argv):
def quote(cmd_argv: str) -> str:
"""Return a shell-escaped version of the string *cmd_argv*."""
if unsafe_re.search(cmd_argv) is None:
return cmd_argv
Expand All @@ -399,14 +402,15 @@ def escape_cmd(cmd):
return " ".join(quote(x) for x in cmd)


def sanitize_list(list_):
def sanitize_list(list_: list[str]) -> list[str]:
"""Sanitize given list by removing all empty entries."""
if list_ is None:
return None
return [x for x in list_ if x]


def parse_args(argv): # pylint: disable=too-many-statements
# pylint: disable-next=too-many-statements
def parse_args(argv: list[str]) -> argparse.Namespace:
"""Parse the given command line arguments."""
parser = argparse.ArgumentParser(description=__doc__)
# parser.add_argument("-m", "--manifest", help="Store packages manifest in given file")
Expand Down Expand Up @@ -711,7 +715,7 @@ def parse_args(argv): # pylint: disable=too-many-statements
return args


def dict_merge(this, other):
def dict_merge(this: dict[str, typing.Any], other: dict[str, typing.Any]) -> None:
"""
Update this dictionary with the key/value pairs from other, merging existing keys.
Return ``None``.
Expand All @@ -723,7 +727,6 @@ def dict_merge(this, other):
:param this: dictionary onto which the merge is executed
:param other: dictionary merged into ``this``
:return: None
"""
for key in other.keys():
if (
Expand All @@ -742,7 +745,7 @@ def dict_merge(this, other):
this[key] = other[key]


def prepare_output_dir(output_dir, force, simulate=False):
def prepare_output_dir(output_dir: str, force: bool, simulate: bool = False) -> bool:
"""Ensure that the output directory exists and is empty."""
logger = logging.getLogger(__script_name__)

Expand Down Expand Up @@ -780,7 +783,7 @@ def prepare_output_dir(output_dir, force, simulate=False):
return True


def main(argv):
def main(argv: list[str]) -> int:
"""Call mmdebstrap with parameters specified in a YAML file."""
start_time = time.time()
args = parse_args(argv)
Expand Down
12 changes: 6 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ class DocCommand(Command):
"""A custom command to build the documentation using pandoc."""

description = "run pandoc to generate man pages"
user_options = []
user_options: list[tuple[str, str, str]] = []

def initialize_options(self):
def initialize_options(self) -> None:
"""Set default values for options."""

def finalize_options(self):
def finalize_options(self) -> None:
"""Post-process options."""

def run(self):
def run(self) -> None:
"""Run pandoc."""
for man_page in MAN_PAGES:
command = ["pandoc", "-s", "-t", "man", man_page + ".md", "-o", man_page]
Expand All @@ -54,15 +54,15 @@ def run(self):
class BuildCommand(distutils.command.build.build):
"""Custom build command (calling doc beforehand)."""

def run(self):
def run(self) -> None:
self.run_command("doc")
super().run()


class CleanCommand(distutils.command.clean.clean):
"""Custom clean command (removing generated man pages)."""

def run(self):
def run(self) -> None:
for man_page in MAN_PAGES:
if os.path.exists(man_page):
self.announce(f"removing {man_page}", level=distutils.log.INFO)
Expand Down
Loading

0 comments on commit 9ba23dc

Please sign in to comment.