diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..eac60ceb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,180 @@ +# https://peps.python.org/pep-0517/ +[build-system] +requires = ["setuptools>=64", "setuptools_scm>=8"] +build-backend = "setuptools.build_meta" + +# https://peps.python.org/pep-0621/ +[project] +name = "toolz" +description = "List processing tools and functional utilities" +readme = "README.rst" +requires-python = ">=3.8" +license = { text = "BSD 3-Clause License" } +maintainers = [{ name = "Erik Welch", email = "erik.n.welch@gmail.com" }] +authors = [ + { name = "Matthew Rocklin"}, + { name = "John Jacobsen"}, + { name = "Erik Welch"}, + { name = "John Crichton"}, + { name = "Han Semaj"}, + { name = "Graeme Coupar"}, + { name = "Leonid Shvechikov"}, + { name = "Lars Buitinck"}, + { name = "José Ricardo"}, + { name = "Tom Prince"}, + { name = "Bart van Merriënboer"}, + { name = "Nikolaos-Digenis Karagiannis"}, + { name = "Antonio Lima"}, + { name = "Joe Jevnik"}, + { name = "Rory Kirchner"}, + { name = "Steven Cutting"}, + { name = "Aric Coady"}, +] +keywords = ["functional", "utility", "itertools", "functools"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: BSD License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Typing :: Typed", +] +dynamic = ["version"] +dependencies = ["typing_extensions"] + +# extras +# https://peps.python.org/pep-0621/#dependencies-optional-dependencies +[project.optional-dependencies] +docs = [ + "sphinx-build", +] +test = [ + 
"pytest>=6.0", + "pytest-cov", +] + +[project.urls] +homepage = "https://github.com/pytoolz/toolz/" +repository = "https://github.com/pytoolz/toolz/" +documentation = "https://toolz.readthedocs.io/" + + +[tool.setuptools] +packages = [ + "toolz", + "toolz.sandbox", + "toolz.curried", + "tlz", +] +include-package-data = true + +[tool.setuptools.package-data] +toolz = [ + "py.typed", + "tests", +] + + +# https://docs.astral.sh/ruff/ +[tool.ruff] +line-length = 80 +src = ["toolz", "tlz"] +extend-exclude = [ + "examples", + "doc", + "bench", +] + +[tool.ruff.lint] +pydocstyle = { convention = "numpy" } +select = [ + "E", # style errors + "F", # flakes + "W", # warnings + "D417", # Missing argument descriptions in Docstrings + "I", # isort + "UP", # pyupgrade + "S", # bandit + "C4", # flake8-comprehensions + "B", # flake8-bugbear + "A001", # flake8-builtins + "ARG", # flake8-unused-arguments + "RET", # flake8-return + "SIM", # flake8-simplify + "TCH", # flake8-typecheck + "TID", # flake8-tidy-imports + "RUF", # ruff-specific rules +] +exclude = [ + "toolz/__init__.py", + "toolz/compatibility.py", + "**/tests/*", +] + +[tool.ruff.lint.per-file-ignores] +"toolz/tests/*.py" = ["B", "S", "F401", "RUF012"] +"toolz/sandbox/tests/*.py" = ["S", "SIM", "RUF012"] +"examples/*.py" = ["S", "RUF012"] +"bench/*.py" = ["RUF012"] +"toolz/_signatures.py" = ["ARG005", "C408"] +"toolz/curried/*.py" = ["F401", "A001"] + +# https://docs.astral.sh/ruff/formatter/ +[tool.ruff.format] +docstring-code-format = false +quote-style = "preserve" +exclude = [ + "toolz/__init__.py", + "toolz/compatibility.py", + "**/tests/*", +] + +# https://mypy.readthedocs.io/en/stable/config_file.html +[tool.mypy] +files = "toolz/**/*.py" +exclude = ["/tests/"] +strict = true +disallow_any_generics = false +disallow_subclassing_any = false +show_error_codes = true +pretty = true + +# https://docs.pytest.org/en/6.2.x/customize.html +[tool.pytest.ini_options] +minversion = "6.0" +testpaths = [ + "toolz/tests", + 
"toolz/sandbox/tests", +] + +# https://coverage.readthedocs.io/en/6.4/config.html +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "if TYPE_CHECKING:", + "@overload", + "except ImportError", + "\\.\\.\\.", + "raise NotImplementedError()", +] +show_missing = true +[tool.coverage.run] +source = ["toolz"] +omit = [ + "toolz/tests/test*", + "toolz/*/tests/test*", + "toolz/compatibility.py", + "toolz/_version.py", +] + +[tool.setuptools_scm] +version_file = "toolz/_version.py" +version_scheme = "post-release" +local_scheme = "dirty-tag" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 40974348..00000000 --- a/setup.cfg +++ /dev/null @@ -1,19 +0,0 @@ -[versioneer] -VCS = git -style = pep440 -versionfile_source = toolz/_version.py -versionfile_build = toolz/_version.py -tag_prefix = -parentdir_prefix = toolz- - -[coverage:run] -source = toolz -omit = - toolz/tests/test* - toolz/*/tests/test* - toolz/compatibility.py - toolz/_version.py - -[coverage:report] -exclude_lines = - pragma: no cover diff --git a/setup.py b/setup.py deleted file mode 100755 index 487303c0..00000000 --- a/setup.py +++ /dev/null @@ -1,38 +0,0 @@ -#!/usr/bin/env python - -from os.path import exists -from setuptools import setup -import versioneer - -setup(name='toolz', - version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), - description='List processing tools and functional utilities', - url='https://github.com/pytoolz/toolz/', - author='https://raw.github.com/pytoolz/toolz/master/AUTHORS.md', - maintainer='Erik Welch', - maintainer_email='erik.n.welch@gmail.com', - license='BSD', - keywords='functional utility itertools functools', - packages=['toolz', - 'toolz.sandbox', - 'toolz.curried', - 'tlz'], - package_data={'toolz': ['tests/*.py']}, - long_description=(open('README.rst').read() if exists('README.rst') - else ''), - zip_safe=False, - python_requires=">=3.7", - classifiers=[ - "Development Status :: 5 - Production/Stable", - "License :: OSI 
Approved :: BSD License", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy"]) diff --git a/tlz/__init__.py b/tlz/__init__.py index 9c9c84af..2c3ab7b9 100644 --- a/tlz/__init__.py +++ b/tlz/__init__.py @@ -7,3 +7,5 @@ """ from . import _build_tlz + +__all__ = ["_build_tlz"] diff --git a/tlz/_build_tlz.py b/tlz/_build_tlz.py index 3ac78369..0d9cd99e 100644 --- a/tlz/_build_tlz.py +++ b/tlz/_build_tlz.py @@ -1,20 +1,29 @@ +from __future__ import annotations + +import contextlib import sys -import types -import toolz from importlib import import_module +from importlib.abc import Loader from importlib.machinery import ModuleSpec +from types import ModuleType +from typing import TYPE_CHECKING + +import toolz +if TYPE_CHECKING: + from collections.abc import Sequence -class TlzLoader: + +class TlzLoader(Loader): """ Finds and loads ``tlz`` modules when added to sys.meta_path""" - def __init__(self): + def __init__(self) -> None: self.always_from_toolz = { toolz.pipe, } - def _load_toolz(self, fullname): + def _load_toolz(self, fullname: str) -> dict[str, ModuleType]: rv = {} - package, dot, submodules = fullname.partition('.') + _, dot, submodules = fullname.partition('.') try: module_name = ''.join(['cytoolz', dot, submodules]) rv['cytoolz'] = import_module(module_name) @@ -29,12 +38,17 @@ def _load_toolz(self, fullname): raise ImportError(fullname) return rv - def find_module(self, fullname, path=None): # pragma: py3 no cover - package, dot, submodules = fullname.partition('.') + def find_module( + self, + fullname: str, + path: Sequence[str] | None = None, # noqa: 
ARG002 + ) -> TlzLoader | None: # pragma: py3 no cover + package, _, __ = fullname.partition('.') if package == 'tlz': return self + return None - def load_module(self, fullname): # pragma: py3 no cover + def load_module(self, fullname: str) -> ModuleType: # pragma: py3 no cover if fullname in sys.modules: # pragma: no cover return sys.modules[fullname] spec = ModuleSpec(fullname, self) @@ -43,15 +57,21 @@ def load_module(self, fullname): # pragma: py3 no cover self.exec_module(module) return module - def find_spec(self, fullname, path, target=None): # pragma: no cover - package, dot, submodules = fullname.partition('.') + def find_spec( + self, + fullname: str, + path: Sequence[str] | None, # noqa: ARG002 + target: ModuleType | None = None, # noqa: ARG002 + ) -> ModuleSpec | None: # pragma: no cover + package, _, __ = fullname.partition('.') if package == 'tlz': return ModuleSpec(fullname, self) + return None - def create_module(self, spec): - return types.ModuleType(spec.name) + def create_module(self, spec: ModuleSpec) -> ModuleType: + return ModuleType(spec.name) - def exec_module(self, module): + def exec_module(self, module: ModuleType) -> None: toolz_mods = self._load_toolz(module.__name__) fast_mod = toolz_mods.get('cytoolz') or toolz_mods['toolz'] slow_mod = toolz_mods.get('toolz') or toolz_mods['cytoolz'] @@ -64,10 +84,8 @@ def exec_module(self, module): module.__doc__ = fast_mod.__doc__ # show file from toolz during introspection - try: + with contextlib.suppress(AttributeError): module.__file__ = slow_mod.__file__ - except AttributeError: - pass for k, v in fast_mod.__dict__.items(): tv = slow_mod.__dict__.get(k) @@ -78,7 +96,7 @@ def exec_module(self, module): if tv in self.always_from_toolz: module.__dict__[k] = tv elif ( - isinstance(v, types.ModuleType) + isinstance(v, ModuleType) and v.__package__ == fast_mod.__name__ ): package, dot, submodules = v.__name__.partition('.') diff --git a/toolz/__init__.py b/toolz/__init__.py index ba49a662..4e28c4f4 
100644 --- a/toolz/__init__.py +++ b/toolz/__init__.py @@ -1,3 +1,5 @@ +from functools import partial, reduce + from .itertoolz import * from .functoolz import * @@ -6,8 +8,6 @@ from .recipes import * -from functools import partial, reduce - sorted = sorted map = map @@ -19,8 +19,8 @@ from . import curried, sandbox -functoolz._sigs.create_signature_registry() +from .functoolz import _sigs # type: ignore[attr-defined] + +_sigs.create_signature_registry() -from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions +from ._version import __version__ diff --git a/toolz/_signatures.py b/toolz/_signatures.py index 27229ef4..ca96a936 100644 --- a/toolz/_signatures.py +++ b/toolz/_signatures.py @@ -12,16 +12,38 @@ Everything in this module should be regarded as implementation details. Users should try to not use this module directly. """ + +from __future__ import annotations + +import builtins import functools import inspect import itertools import operator from importlib import import_module +from typing import TYPE_CHECKING, Any, Callable, Mapping + +from .functoolz import ( + has_keywords, + has_varargs, + is_arity, + is_partial_args, + num_required_args, +) -from .functoolz import (is_partial_args, is_arity, has_varargs, - has_keywords, num_required_args) +if TYPE_CHECKING: + from types import ModuleType + + ExpandSignatureTuple = tuple[ + int, + Callable, + tuple[str, ...], + inspect.Signature | None, + ] + ExpandSignatureInput = ( + Callable | tuple[int, Callable] | tuple[int, Callable, tuple[str, ...]] + ) -import builtins # We mock builtin callables using lists of tuples with lambda functions. # @@ -38,7 +60,7 @@ # keyword_only_args: (optional) # - Tuple of keyword-only arguments. 
-module_info = {} +module_info: dict[str | ModuleType, Any] = {} module_info[builtins] = dict( abs=[ @@ -595,33 +617,43 @@ ) -def num_pos_args(sigspec): +def num_pos_args(sigspec: inspect.Signature | None) -> int: """ Return the number of positional arguments. ``f(x, y=1)`` has 1""" - return sum(1 for x in sigspec.parameters.values() - if x.kind == x.POSITIONAL_OR_KEYWORD - and x.default is x.empty) - - -def get_exclude_keywords(num_pos_only, sigspec): + if sigspec is None: + return -1 + return sum( + 1 + for x in sigspec.parameters.values() + if x.kind == x.POSITIONAL_OR_KEYWORD and x.default is x.empty + ) + + +def get_exclude_keywords( + num_pos_only: int, + sigspec: inspect.Signature | None, +) -> tuple[str, ...]: """ Return the names of position-only arguments if func has **kwargs""" - if num_pos_only == 0: + if num_pos_only == 0 or sigspec is None: return () - has_kwargs = any(x.kind == x.VAR_KEYWORD - for x in sigspec.parameters.values()) + has_kwargs = any( + x.kind == x.VAR_KEYWORD for x in sigspec.parameters.values() + ) if not has_kwargs: return () pos_args = list(sigspec.parameters.values())[:num_pos_only] return tuple(x.name for x in pos_args) -def signature_or_spec(func): +def signature_or_spec(func: Callable) -> inspect.Signature | None: try: return inspect.signature(func) except (ValueError, TypeError): return None -def expand_sig(sig): +def expand_sig( + sig: ExpandSignatureInput, +) -> ExpandSignatureTuple: """ Convert the signature spec in ``module_info`` to add to ``signatures`` The input signature spec is one of: @@ -641,7 +673,7 @@ def expand_sig(sig): if isinstance(sig, tuple): if len(sig) == 3: num_pos_only, func, keyword_only = sig - assert isinstance(sig[-1], tuple) + # assert isinstance(sig[-1], tuple) else: num_pos_only, func = sig keyword_only = () @@ -655,10 +687,13 @@ def expand_sig(sig): return num_pos_only, func, keyword_only + keyword_exclude, sigspec -signatures = {} +signatures: dict[Callable, tuple[ExpandSignatureTuple, ...]] = 
{} -def create_signature_registry(module_info=module_info, signatures=signatures): +def create_signature_registry( + module_info: dict = module_info, + signatures: dict[Callable, tuple[ExpandSignatureTuple, ...]] = signatures, +) -> None: for module, info in module_info.items(): if isinstance(module, str): module = import_module(module) @@ -668,7 +703,11 @@ def create_signature_registry(module_info=module_info, signatures=signatures): signatures[getattr(module, name)] = new_sigs -def check_valid(sig, args, kwargs): +def check_valid( + sig: ExpandSignatureTuple, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], +) -> bool | None: """ Like ``is_valid_args`` for the given signature spec""" num_pos_only, func, keyword_exclude, sigspec = sig if len(args) < num_pos_only: @@ -684,7 +723,11 @@ def check_valid(sig, args, kwargs): return False -def _is_valid_args(func, args, kwargs): +def _is_valid_args( + func: Callable, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], +) -> bool | None: """ Like ``is_valid_args`` for builtins in our ``signatures`` registry""" if func not in signatures: return None @@ -692,7 +735,11 @@ def _is_valid_args(func, args, kwargs): return any(check_valid(sig, args, kwargs) for sig in sigs) -def check_partial(sig, args, kwargs): +def check_partial( + sig: ExpandSignatureTuple, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], +) -> bool | None: """ Like ``is_partial_args`` for the given signature spec""" num_pos_only, func, keyword_exclude, sigspec = sig if len(args) < num_pos_only: @@ -705,7 +752,11 @@ def check_partial(sig, args, kwargs): return is_partial_args(func, args, kwargs, sigspec) -def _is_partial_args(func, args, kwargs): +def _is_partial_args( + func: Callable, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], +) -> bool | None: """ Like ``is_partial_args`` for builtins in our ``signatures`` registry""" if func not in signatures: return None @@ -713,67 +764,67 @@ def _is_partial_args(func, args, kwargs): return 
any(check_partial(sig, args, kwargs) for sig in sigs) -def check_arity(n, sig): +def check_arity(n: int, sig: ExpandSignatureTuple) -> bool | None: num_pos_only, func, keyword_exclude, sigspec = sig if keyword_exclude or num_pos_only > n: return False return is_arity(n, func, sigspec) -def _is_arity(n, func): +def _is_arity(n: int, func: Callable) -> bool | None: if func not in signatures: return None sigs = signatures[func] checks = [check_arity(n, sig) for sig in sigs] if all(checks): return True - elif any(checks): + if any(checks): return None return False -def check_varargs(sig): +def check_varargs(sig: ExpandSignatureTuple) -> bool | None: num_pos_only, func, keyword_exclude, sigspec = sig return has_varargs(func, sigspec) -def _has_varargs(func): +def _has_varargs(func: Callable) -> bool | None: if func not in signatures: return None sigs = signatures[func] checks = [check_varargs(sig) for sig in sigs] if all(checks): return True - elif any(checks): + if any(checks): return None return False -def check_keywords(sig): +def check_keywords(sig: ExpandSignatureTuple) -> bool | None: num_pos_only, func, keyword_exclude, sigspec = sig if keyword_exclude: return True return has_keywords(func, sigspec) -def _has_keywords(func): +def _has_keywords(func: Callable) -> bool | None: if func not in signatures: return None sigs = signatures[func] checks = [check_keywords(sig) for sig in sigs] if all(checks): return True - elif any(checks): + if any(checks): return None return False -def check_required_args(sig): +def check_required_args(sig: ExpandSignatureTuple) -> int | bool | None: num_pos_only, func, keyword_exclude, sigspec = sig return num_required_args(func, sigspec) -def _num_required_args(func): +def _num_required_args(func: Callable) -> int | bool | None: if func not in signatures: return None sigs = signatures[func] diff --git a/toolz/_version.py b/toolz/_version.py index 6e0bd8ff..b7d4c259 100644 --- a/toolz/_version.py +++ b/toolz/_version.py @@ -1,520 +1,16 
@@ - -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - git_date = "$Format:%ci$" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "pep440" - cfg.tag_prefix = "" - cfg.parentdir_prefix = "toolz-" - cfg.versionfile_source = "toolz/_version.py" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - 
return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. 
We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. 
However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. 
"2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse 
describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. 
git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 
0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. 
- for i in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} +# file generated by setuptools_scm +# don't change, don't track in version control +TYPE_CHECKING = False +if TYPE_CHECKING: + from typing import Tuple, Union + VERSION_TUPLE = Tuple[Union[int, str], ...] +else: + VERSION_TUPLE = object + +version: str +__version__: str +__version_tuple__: VERSION_TUPLE +version_tuple: VERSION_TUPLE + +__version__ = version = '0.12.1.post0+dirty' +__version_tuple__ = version_tuple = (0, 12, 1, 'dirty') diff --git a/toolz/compatibility.py b/toolz/compatibility.py index 28bef91d..b82d9971 100644 --- a/toolz/compatibility.py +++ b/toolz/compatibility.py @@ -1,9 +1,13 @@ import warnings -warnings.warn("The toolz.compatibility module is no longer " - "needed in Python 3 and has been deprecated. Please " - "import these utilities directly from the standard library. " - "This module will be removed in a future release.", - category=DeprecationWarning, stacklevel=2) + +warnings.warn( + "The toolz.compatibility module is no longer " + "needed in Python 3 and has been deprecated. Please " + "import these utilities directly from the standard library. 
" + "This module will be removed in a future release.", + category=DeprecationWarning, + stacklevel=2, +) import operator import sys @@ -22,9 +26,8 @@ range = range zip = zip from functools import reduce -from itertools import zip_longest -from itertools import filterfalse +from itertools import filterfalse, zip_longest + iteritems = operator.methodcaller('items') iterkeys = operator.methodcaller('keys') itervalues = operator.methodcaller('values') -from collections.abc import Sequence diff --git a/toolz/curried/__init__.py b/toolz/curried/__init__.py index 356eddbd..ea1e5e65 100644 --- a/toolz/curried/__init__.py +++ b/toolz/curried/__init__.py @@ -23,10 +23,9 @@ See Also: toolz.functoolz.curry """ + import toolz -from . import operator from toolz import ( - apply, comp, complement, compose, @@ -36,6 +35,7 @@ count, curry, diff, + excepts, first, flip, frequencies, @@ -53,9 +53,12 @@ thread_first, thread_last, ) + +from . import operator from .exceptions import merge, merge_with accumulate = toolz.curry(toolz.accumulate) +apply = toolz.curry(toolz.apply) assoc = toolz.curry(toolz.assoc) assoc_in = toolz.curry(toolz.assoc_in) cons = toolz.curry(toolz.cons) @@ -63,8 +66,7 @@ dissoc = toolz.curry(toolz.dissoc) do = toolz.curry(toolz.do) drop = toolz.curry(toolz.drop) -excepts = toolz.curry(toolz.excepts) -filter = toolz.curry(toolz.filter) +filter: toolz.curry = toolz.curry(toolz.filter) get = toolz.curry(toolz.get) get_in = toolz.curry(toolz.get_in) groupby = toolz.curry(toolz.groupby) @@ -78,14 +80,14 @@ map = toolz.curry(toolz.map) mapcat = toolz.curry(toolz.mapcat) nth = toolz.curry(toolz.nth) -partial = toolz.curry(toolz.partial) +partial = toolz.curry(toolz.partial) # type: ignore[attr-defined] partition = toolz.curry(toolz.partition) partition_all = toolz.curry(toolz.partition_all) partitionby = toolz.curry(toolz.partitionby) peekn = toolz.curry(toolz.peekn) pluck = toolz.curry(toolz.pluck) random_sample = toolz.curry(toolz.random_sample) -reduce = 
toolz.curry(toolz.reduce) +reduce = toolz.curry(toolz.reduce) # type: ignore[attr-defined] reduceby = toolz.curry(toolz.reduceby) remove = toolz.curry(toolz.remove) sliding_window = toolz.curry(toolz.sliding_window) @@ -99,5 +101,4 @@ valfilter = toolz.curry(toolz.valfilter) valmap = toolz.curry(toolz.valmap) -del exceptions del toolz diff --git a/toolz/curried/exceptions.py b/toolz/curried/exceptions.py index 75a52bbb..0405b08d 100644 --- a/toolz/curried/exceptions.py +++ b/toolz/curried/exceptions.py @@ -1,16 +1,39 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import toolz +if TYPE_CHECKING: + from collections.abc import Mapping, MutableMapping + from typing import Callable, Sequence, TypeVar + + _S = TypeVar('_S') + _T = TypeVar('_T') + _U = TypeVar('_U') + + _DictType = MutableMapping[_S, _T] + __all__ = ['merge_with', 'merge'] @toolz.curry -def merge_with(func, d, *dicts, **kwargs): +def merge_with( + func: Callable[[Sequence[_T]], _U], + d: Mapping[_S, _T], + *dicts: Mapping[_S, _T], + **kwargs: type[_DictType], +) -> _DictType[_S, _U]: return toolz.merge_with(func, d, *dicts, **kwargs) @toolz.curry -def merge(d, *dicts, **kwargs): +def merge( + d: Mapping[_S, _T], + *dicts: Mapping[_S, _T], + **kwargs: type[_DictType], +) -> _DictType[_S, _T]: return toolz.merge(d, *dicts, **kwargs) diff --git a/toolz/curried/operator.py b/toolz/curried/operator.py index 35979a68..c0fd8bf4 100644 --- a/toolz/curried/operator.py +++ b/toolz/curried/operator.py @@ -1,19 +1,20 @@ -from __future__ import absolute_import import operator from toolz.functoolz import curry - # Tests will catch if/when this needs updated IGNORE = { - "__abs__", "__index__", "__inv__", "__invert__", "__neg__", "__not__", - "__pos__", "_abs", "abs", "attrgetter", "index", "inv", "invert", - "itemgetter", "neg", "not_", "pos", "truth" + '__abs__', '__index__', '__inv__', '__invert__', '__neg__', '__not__', + '__pos__', '_abs', 'abs', 'attrgetter', 'index', 'inv', 'invert', 
+ 'itemgetter', 'neg', 'not_', 'pos', 'truth' } locals().update( - {name: f if name in IGNORE else curry(f) - for name, f in vars(operator).items() if callable(f)} + { + name: f if name in IGNORE else curry(f) + for name, f in vars(operator).items() + if callable(f) + } ) # Clean up the namespace. diff --git a/toolz/dicttoolz.py b/toolz/dicttoolz.py index 457bc269..22c3e366 100644 --- a/toolz/dicttoolz.py +++ b/toolz/dicttoolz.py @@ -1,22 +1,32 @@ +from __future__ import annotations + import operator -import collections +from collections import defaultdict +from collections.abc import Callable, Mapping, MutableMapping from functools import reduce -from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Sequence __all__ = ('merge', 'merge_with', 'valmap', 'keymap', 'itemmap', 'valfilter', 'keyfilter', 'itemfilter', 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in') +if TYPE_CHECKING: + from typing import TypeVar -def _get_factory(f, kwargs): - factory = kwargs.pop('factory', dict) - if kwargs: - raise TypeError("{}() got an unexpected keyword argument " - "'{}'".format(f.__name__, kwargs.popitem()[0])) - return factory + _S = TypeVar('_S') + _T = TypeVar('_T') + _U = TypeVar('_U') + _V = TypeVar('_V') + _DictType = MutableMapping[_S, _T] + Predicate = Callable[[_T], Any] + TransformOp = Callable[[_T], _S] + Filter = Callable[[_T], bool] -def merge(*dicts, **kwargs): +def merge( + *dicts: Mapping[_S, _T], + factory: type[_DictType] = dict, +) -> _DictType[_S, _T]: """ Merge a collection of dictionaries >>> merge({1: 'one'}, {2: 'two'}) @@ -32,7 +42,6 @@ def merge(*dicts, **kwargs): """ if len(dicts) == 1 and not isinstance(dicts[0], Mapping): dicts = dicts[0] - factory = _get_factory(merge, kwargs) rv = factory() for d in dicts: @@ -40,7 +49,11 @@ def merge(*dicts, **kwargs): return rv -def merge_with(func, *dicts, **kwargs): +def merge_with( + func: Callable[[Sequence[_T]], _U], + *dicts: Mapping[_S, _T], + factory: type[_DictType] = dict, +) 
-> _DictType[_S, _U]: """ Merge dictionaries and apply function to combined values A key may occur in more than one dict, and all values mapped from the key @@ -57,20 +70,23 @@ def merge_with(func, *dicts, **kwargs): """ if len(dicts) == 1 and not isinstance(dicts[0], Mapping): dicts = dicts[0] - factory = _get_factory(merge_with, kwargs) - values = collections.defaultdict(lambda: [].append) + values = defaultdict(list) for d in dicts: for k, v in d.items(): - values[k](v) + values[k].append(v) result = factory() - for k, v in values.items(): - result[k] = func(v.__self__) + for k, list_v in values.items(): + result[k] = func(list_v) return result -def valmap(func, d, factory=dict): +def valmap( + func: TransformOp[_T, _U], + d: Mapping[_S, _T], + factory: type[_DictType] = dict, +) -> _DictType[_S, _U]: """ Apply function to values of dictionary >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} @@ -86,7 +102,11 @@ def valmap(func, d, factory=dict): return rv -def keymap(func, d, factory=dict): +def keymap( + func: TransformOp[_S, _U], + d: Mapping[_S, _T], + factory: type[_DictType] = dict, +) -> _DictType[_U, _T]: """ Apply function to keys of dictionary >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} @@ -102,7 +122,11 @@ def keymap(func, d, factory=dict): return rv -def itemmap(func, d, factory=dict): +def itemmap( + func: Callable[[tuple[_S, _T]], tuple[_U, _V]], + d: Mapping[_S, _T], + factory: type[_DictType] = dict, +) -> _DictType[_U, _V]: """ Apply function to items of dictionary >>> accountids = {"Alice": 10, "Bob": 20} @@ -118,7 +142,11 @@ def itemmap(func, d, factory=dict): return rv -def valfilter(predicate, d, factory=dict): +def valfilter( + predicate: Filter[_T], + d: Mapping[_S, _T], + factory: type[_DictType] = dict, +) -> _DictType[_S, _T]: """ Filter items in dictionary by value >>> iseven = lambda x: x % 2 == 0 @@ -138,7 +166,11 @@ def valfilter(predicate, d, factory=dict): return rv -def keyfilter(predicate, d, factory=dict): +def 
keyfilter( + predicate: Filter[_S], + d: Mapping[_S, _T], + factory: type[_DictType] = dict, +) -> _DictType[_S, _T]: """ Filter items in dictionary by key >>> iseven = lambda x: x % 2 == 0 @@ -158,7 +190,11 @@ def keyfilter(predicate, d, factory=dict): return rv -def itemfilter(predicate, d, factory=dict): +def itemfilter( + predicate: Filter[tuple[_S, _T]], + d: Mapping[_S, _T], + factory: type[_DictType] = dict, +) -> _DictType[_S, _T]: """ Filter items in dictionary by item >>> def isvalid(item): @@ -182,7 +218,12 @@ def itemfilter(predicate, d, factory=dict): return rv -def assoc(d, key, value, factory=dict): +def assoc( + d: Mapping[_S, _T], + key: _S, + value: _T, + factory: type[_DictType] = dict, +) -> _DictType[_S, _T]: """ Return a new dict with new key value pair New dict has d[key] set to value. Does not modify the initial dictionary. @@ -198,7 +239,12 @@ def assoc(d, key, value, factory=dict): return d2 -def dissoc(d, *keys, **kwargs): +def dissoc( + d: Mapping[_S, _T], + *keys: _S, + factory: type[_DictType] = dict, +) -> _DictType[_S, _T]: + """ Return a new dict with the given key(s) removed. New dict has d[key] deleted for each supplied key. 
@@ -211,10 +257,9 @@ def dissoc(d, *keys, **kwargs): >>> dissoc({'x': 1}, 'y') # Ignores missing keys {'x': 1} """ - factory = _get_factory(dissoc, kwargs) d2 = factory() - if len(keys) < len(d) * .6: + if len(keys) < len(d) * 0.6: d2.update(d) for key in keys: if key in d2: @@ -227,7 +272,12 @@ def dissoc(d, *keys, **kwargs): return d2 -def assoc_in(d, keys, value, factory=dict): +def assoc_in( + d: Mapping[_S, _T], + keys: Sequence, + value: Any, + factory: type[_DictType] | None = None, +) -> _DictType[_S, _T]: """ Return a new dict with new, potentially nested, key value pair >>> purchase = {'name': 'Alice', @@ -239,10 +289,16 @@ def assoc_in(d, keys, value, factory=dict): 'name': 'Alice', 'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}} """ - return update_in(d, keys, lambda x: value, value, factory) + return update_in(d, keys, lambda _: value, value, factory) -def update_in(d, keys, func, default=None, factory=dict): +def update_in( + d: Mapping[_S, _T], + keys: Sequence, + func: Callable, + default: Any = None, + factory: type[_DictType] | None = None, +) -> _DictType[_S, _T]: """ Update value in a (potentially) nested dictionary inputs: @@ -276,19 +332,20 @@ def update_in(d, keys, func, default=None, factory=dict): >>> update_in({1: 'foo'}, [2, 3, 4], inc, 0) {1: 'foo', 2: {3: {4: 1}}} """ + dict_factory = factory or dict ks = iter(keys) k = next(ks) - rv = inner = factory() + rv = inner = dict_factory() rv.update(d) for key in ks: if k in d: - d = d[k] - dtemp = factory() + d = d[k] # type: ignore[assignment] + dtemp: _DictType = dict_factory() dtemp.update(d) else: - d = dtemp = factory() + d = dtemp = dict_factory() inner[k] = inner = dtemp k = key @@ -300,7 +357,12 @@ def update_in(d, keys, func, default=None, factory=dict): return rv -def get_in(keys, coll, default=None, no_default=False): +def get_in( + keys: Sequence, + coll: Mapping, + default: Any = None, + no_default: bool = False, +) -> Any: """ Returns coll[i0][i1]...[iX] where 
[i0, i1, ..., iX]==keys. If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless diff --git a/toolz/functoolz.py b/toolz/functoolz.py index 7709f15b..02c4a5b9 100644 --- a/toolz/functoolz.py +++ b/toolz/functoolz.py @@ -1,13 +1,55 @@ -from functools import reduce, partial +from __future__ import annotations + +import contextlib import inspect import sys -from operator import attrgetter, not_ +from functools import partial, reduce from importlib import import_module +from operator import attrgetter, not_ from types import MethodType +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload + +if sys.version_info < (3, 10): + from typing_extensions import ParamSpec +else: + from typing import ParamSpec from .utils import no_default -PYPY = hasattr(sys, 'pypy_version_info') and sys.version_info[0] > 2 +_S = TypeVar('_S') +_T = TypeVar('_T') +_U = TypeVar('_U') +_Instance = TypeVar('_Instance') +_P = ParamSpec('_P') + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping, Sequence + from types import NotImplementedType + + from typing_extensions import Literal, TypeGuard, TypeVarTuple, Unpack + + _Ts = TypeVarTuple('_Ts') + + Getter = Callable[[_Instance], _T] + Setter = Callable[[_Instance, _T], None] + Deleter = Callable[[_Instance], None] + InstancePropertyState = tuple[ + Getter[_Instance, _T] | None, + Setter[_Instance, _T] | None, + Deleter[_Instance] | None, + str | None, + _T, + ] + TransformOp = Callable[[_T], _S] + TupleTransformBack = tuple[ + Callable[[Unpack[_Ts], _S], _T], + Unpack[_Ts], + ] # >= py311 + TupleTransformFront = tuple[ + Callable[[_S, Unpack[_Ts]], _T], + Unpack[_Ts], + ] # >= py311 + _CurryState = tuple __all__ = ('identity', 'apply', 'thread_first', 'thread_last', 'memoize', @@ -17,7 +59,7 @@ PYPY = hasattr(sys, 'pypy_version_info') -def identity(x): +def identity(x: _T) -> _T: """ Identity function. 
Return x >>> identity(3) @@ -26,7 +68,7 @@ def identity(x): return x -def apply(*func_and_args, **kwargs): +def apply(func: Callable[..., _T], /, *args: Any, **kwargs: Any) -> _T: """ Applies a function and returns the results >>> def double(x): return 2*x @@ -37,13 +79,10 @@ def apply(*func_and_args, **kwargs): >>> tuple(map(apply, [double, inc, double], [10, 500, 8000])) (20, 501, 16000) """ - if not func_and_args: - raise TypeError('func argument is required') - func, args = func_and_args[0], func_and_args[1:] return func(*args, **kwargs) -def thread_first(val, *forms): +def thread_first(val: Any, *forms: TransformOp | TupleTransformFront) -> Any: """ Thread value through a sequence of functions/forms >>> def double(x): return 2*x @@ -67,17 +106,30 @@ def thread_first(val, *forms): See Also: thread_last """ - def evalform_front(val, form): + + @overload + def evalform_front(val: _S, form: TransformOp[_S, _T]) -> _T: ... + + @overload + def evalform_front( + val: _S, form: TupleTransformFront[_S, Unpack[_Ts], _T] + ) -> _T: ... 
+ + def evalform_front( + val: _S, + form: TransformOp[_S, _T] | TupleTransformFront[_S, Unpack[_Ts], _T], + ) -> _T: if callable(form): return form(val) - if isinstance(form, tuple): - func, args = form[0], form[1:] - args = (val,) + args - return func(*args) - return reduce(evalform_front, forms, val) + # if isinstance(form, tuple): + func, args = form[0], form[1:] + all_args = (val, *args) + return func(*all_args) + + return reduce(evalform_front, forms, val) # type: ignore[arg-type] -def thread_last(val, *forms): +def thread_last(val: Any, *forms: TransformOp | TupleTransformBack) -> Any: """ Thread value through a sequence of functions/forms >>> def double(x): return 2*x @@ -106,17 +158,59 @@ def thread_last(val, *forms): See Also: thread_first """ - def evalform_back(val, form): - if callable(form): - return form(val) - if isinstance(form, tuple): - func, args = form[0], form[1:] - args = args + (val,) - return func(*args) - return reduce(evalform_back, forms, val) + @overload + def evalform_back(val: _S, form: TransformOp[_S, _T]) -> _T: ... + + @overload + def evalform_back( + val: _S, form: TupleTransformBack[Unpack[_Ts], _S, _T] + ) -> _T: ... -def instanceproperty(fget=None, fset=None, fdel=None, doc=None, classval=None): + def evalform_back( + val: _S, + form: TransformOp[_S, _T] | TupleTransformBack[Unpack[_Ts], _S, _T], + ) -> _T: + if callable(form): + return form(val) + # if isinstance(form, tuple): + func, args = form[0], form[1:] + all_args = (*args, val) + return func(*all_args) + + return reduce(evalform_back, forms, val) # type: ignore[arg-type] + + +@overload +def instanceproperty( + fget: Getter[_Instance, _T], + fset: Setter[_Instance, _T] | None = ..., + fdel: Deleter[_Instance] | None = ..., + doc: str | None = ..., + classval: _T | None = ..., +) -> InstanceProperty[_Instance, _T]: ... 
+ + +@overload +def instanceproperty( + fget: Literal[None] | None = None, + fset: Setter[_Instance, _T] | None = ..., + fdel: Deleter[_Instance] | None = ..., + doc: str | None = ..., + classval: _T | None = ..., +) -> Callable[[Getter[_Instance, _T]], InstanceProperty[_Instance, _T]]: ... + + +def instanceproperty( + fget: Getter[_Instance, _T] | None = None, + fset: Setter[_Instance, _T] | None = None, + fdel: Deleter[_Instance] | None = None, + doc: str | None = None, + classval: _T | None = None, +) -> ( + InstanceProperty[_Instance, _T] + | Callable[[Getter[_Instance, _T]], InstanceProperty[_Instance, _T]] +): """ Like @property, but returns ``classval`` when used as a class attribute >>> class MyClass(object): @@ -139,33 +233,63 @@ def instanceproperty(fget=None, fset=None, fdel=None, doc=None, classval=None): 42 """ if fget is None: - return partial(instanceproperty, fset=fset, fdel=fdel, doc=doc, - classval=classval) - return InstanceProperty(fget=fget, fset=fset, fdel=fdel, doc=doc, - classval=classval) + return partial( + instanceproperty, fset=fset, fdel=fdel, doc=doc, classval=classval + ) + return InstanceProperty( + fget=fget, fset=fset, fdel=fdel, doc=doc, classval=classval + ) -class InstanceProperty(property): +class InstanceProperty(Generic[_Instance, _T], property): """ Like @property, but returns ``classval`` when used as a class attribute Should not be used directly. Use ``instanceproperty`` instead. """ - def __init__(self, fget=None, fset=None, fdel=None, doc=None, - classval=None): + + def __init__( + self, + fget: Getter[_Instance, _T] | None = None, + fset: Setter[_Instance, _T] | None = None, + fdel: Deleter[_Instance] | None = None, + doc: str | None = None, + classval: _T | None = None, + ) -> None: self.classval = classval property.__init__(self, fget=fget, fset=fset, fdel=fdel, doc=doc) - def __get__(self, obj, type=None): + @overload + def __get__(self, obj: None, type: type | None = ...) -> _T | None: ... 
+ + @overload + def __get__(self, obj: _Instance, type: type | None = ...) -> _T: ... + + def __get__( + self, obj: _Instance | None, type: type | None = None + ) -> _T | None: if obj is None: return self.classval - return property.__get__(self, obj, type) + return cast(_T, property.__get__(self, obj, type)) - def __reduce__(self): + def __reduce__( + self, + ) -> tuple[type[InstanceProperty], InstancePropertyState]: state = (self.fget, self.fset, self.fdel, self.__doc__, self.classval) return InstanceProperty, state -class curry(object): +def is_partial_function(func: Callable) -> TypeGuard[partial]: + if ( + hasattr(func, 'func') + and hasattr(func, 'args') + and hasattr(func, 'keywords') + and isinstance(func.args, tuple) + ): + return True + return False + + +class curry(Generic[_T]): """ Curry a callable function Enables partial application of arguments through calling a function with an @@ -181,7 +305,7 @@ class curry(object): Also supports keyword arguments - >>> @curry # Can use curry as a decorator + >>> @curry # Can use curry as a decorator ... def f(x, y, a=10): ... return a * (x + y) @@ -193,20 +317,20 @@ class curry(object): toolz.curried - namespace of curried functions https://toolz.readthedocs.io/en/latest/curry.html """ - def __init__(self, *args, **kwargs): - if not args: - raise TypeError('__init__() takes at least 2 arguments (1 given)') - func, args = args[0], args[1:] + + def __init__( + self, + # TODO: type hint that the returned value of a `partial` is _T + func: curry[_T] | partial | Callable[..., _T], + /, # `func` is positional only, cannot be passed as keyword + *args: Any, + **kwargs: Any, + ) -> None: if not callable(func): - raise TypeError("Input must be callable") + raise TypeError('Input must be callable') # curry- or functools.partial-like object? 
Unpack and merge arguments - if ( - hasattr(func, 'func') - and hasattr(func, 'args') - and hasattr(func, 'keywords') - and isinstance(func.args, tuple) - ): + if is_partial_function(func): _kwargs = {} if func.keywords: _kwargs.update(func.keywords) @@ -222,17 +346,17 @@ def __init__(self, *args, **kwargs): self.__doc__ = getattr(func, '__doc__', None) self.__name__ = getattr(func, '__name__', '') - self.__module__ = getattr(func, '__module__', None) - self.__qualname__ = getattr(func, '__qualname__', None) - self._sigspec = None - self._has_unknown_args = None + self.__module__ = getattr(func, '__module__', '') + self.__qualname__ = getattr(func, '__qualname__', '') + self._sigspec: inspect.Signature | None = None + self._has_unknown_args: bool | None = None @instanceproperty - def func(self): + def func(self) -> Callable[..., _T]: return self._partial.func @instanceproperty - def __signature__(self): + def __signature__(self) -> inspect.Signature: sig = inspect.signature(self.func) args = self.args or () keywords = self.keywords or {} @@ -241,7 +365,7 @@ def __signature__(self): params = list(sig.parameters.values()) skip = 0 - for param in params[:len(args)]: + for param in params[: len(args)]: if param.kind == param.VAR_POSITIONAL: break skip += 1 @@ -270,44 +394,57 @@ def __signature__(self): return sig.replace(parameters=newparams) @instanceproperty - def args(self): + def args(self) -> tuple[Any, ...]: return self._partial.args @instanceproperty - def keywords(self): + def keywords(self) -> dict[str, Any]: return self._partial.keywords @instanceproperty - def func_name(self): + def func_name(self) -> str: return self.__name__ - def __str__(self): + def __str__(self) -> str: return str(self.func) - def __repr__(self): + def __repr__(self) -> str: return repr(self.func) - def __hash__(self): - return hash((self.func, self.args, - frozenset(self.keywords.items()) if self.keywords - else None)) + def __hash__(self) -> int: + return hash( + ( + self.func, + 
self.args, + frozenset(self.keywords.items()) if self.keywords else None, + ) + ) - def __eq__(self, other): - return (isinstance(other, curry) and self.func == other.func and - self.args == other.args and self.keywords == other.keywords) + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, curry) + and self.func == other.func + and self.args == other.args + and self.keywords == other.keywords + ) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> _T | curry[_T]: try: - return self._partial(*args, **kwargs) + return self.call(*args, **kwargs) except TypeError as exc: if self._should_curry(args, kwargs, exc): return self.bind(*args, **kwargs) raise - def _should_curry(self, args, kwargs, exc=None): + def _should_curry( + self, + args: tuple[Any, ...], + kwargs: Mapping, + exc: Exception | None = None, # noqa: ARG002 + ) -> bool: func = self.func args = self.args + args if self.keywords: @@ -321,35 +458,34 @@ def _should_curry(self, args, kwargs, exc=None): if is_partial_args(func, args, kwargs, sigspec) is False: # Nothing can make the call valid return False - elif self._has_unknown_args: + if self._has_unknown_args: # The call may be valid and raised a TypeError, but we curry # anyway because the function may have `*args`. This is useful # for decorators with signature `func(*args, **kwargs)`. 
return True - elif not is_valid_args(func, args, kwargs, sigspec): + if not is_valid_args(func, args, kwargs, sigspec): # Adding more arguments may make the call valid return True - else: - # There was a genuine TypeError - return False + # There was a genuine TypeError + return False - def bind(self, *args, **kwargs): + def bind(self, *args: Any, **kwargs: Any) -> curry[_T]: return type(self)(self, *args, **kwargs) - def call(self, *args, **kwargs): - return self._partial(*args, **kwargs) + def call(self, *args: Any, **kwargs: Any) -> _T: + return cast(_T, self._partial(*args, **kwargs)) - def __get__(self, instance, owner): + def __get__(self, instance: object, owner: type) -> curry[_T]: if instance is None: return self return curry(self, instance) - def __reduce__(self): + def __reduce__(self) -> tuple[Callable, _CurryState]: func = self.func - modname = getattr(func, '__module__', None) - qualname = getattr(func, '__qualname__', None) + modname = getattr(func, '__module__', '') + qualname = getattr(func, '__qualname__', '') if qualname is None: # pragma: no cover - qualname = getattr(func, '__name__', None) + qualname = getattr(func, '__name__', '') is_decorated = None if modname and qualname: attrs = [] @@ -357,47 +493,69 @@ def __reduce__(self): for attr in qualname.split('.'): if isinstance(obj, curry): attrs.append('func') - obj = obj.func - obj = getattr(obj, attr, None) + obj = obj.func # type: ignore[assignment] + obj = getattr(obj, attr, None) # type: ignore[assignment] if obj is None: break attrs.append(attr) if isinstance(obj, curry) and obj.func is func: is_decorated = obj is self qualname = '.'.join(attrs) - func = '%s:%s' % (modname, qualname) + func = f'{modname}:{qualname}' # type: ignore[assignment] # functools.partial objects can't be pickled - userdict = tuple((k, v) for k, v in self.__dict__.items() - if k not in ('_partial', '_sigspec')) - state = (type(self), func, self.args, self.keywords, userdict, - is_decorated) + userdict = tuple( + 
(k, v) + for k, v in self.__dict__.items() + if k not in ('_partial', '_sigspec') + ) + state = ( + type(self), + func, + self.args, + self.keywords, + userdict, + is_decorated, + ) return _restore_curry, state -def _restore_curry(cls, func, args, kwargs, userdict, is_decorated): +def _restore_curry( + cls: type[curry[_T]], + func: str | Callable[..., _T], + args: tuple[Any, ...], + kwargs: dict[str, Any], + userdict: Mapping[str, Any], + is_decorated: bool, +) -> curry[_T]: if isinstance(func, str): modname, qualname = func.rsplit(':', 1) obj = import_module(modname) for attr in qualname.split('.'): obj = getattr(obj, attr) if is_decorated: - return obj + return obj # type: ignore[return-value] func = obj.func - obj = cls(func, *args, **(kwargs or {})) + # TODO: check that func is callable + func = cast('Callable[..., _T]', func) + obj = cls(func, *args, **(kwargs or {})) # type: ignore[assignment] obj.__dict__.update(userdict) - return obj + return obj # type: ignore[return-value] @curry -def memoize(func, cache=None, key=None): +def memoize( + func: Callable[..., _T], + cache: dict[Any, _T] | None = None, + key: Callable[[tuple, Mapping], Any] | None = None, +) -> Callable[..., _T]: """ Cache a function's result for speedy future evaluation Considerations: Trades memory for speed. Only use on pure functions. 
- >>> def add(x, y): return x + y + >>> def add(x, y): return x + y >>> add = memoize(add) Or use as a decorator @@ -439,117 +597,125 @@ def memoize(func, cache=None, key=None): if key is None: if is_unary: - def key(args, kwargs): + + def key(args: tuple, kwargs: Mapping) -> Any: # noqa: ARG001 return args[0] elif may_have_kwargs: - def key(args, kwargs): + + def key(args: tuple, kwargs: Mapping) -> Any: return ( args or None, frozenset(kwargs.items()) if kwargs else None, ) else: - def key(args, kwargs): + + def key(args: tuple, kwargs: Mapping) -> Any: # noqa: ARG001 return args - def memof(*args, **kwargs): + def memof(*args: Any, **kwargs: Any) -> _T: k = key(args, kwargs) try: return cache[k] - except TypeError: - raise TypeError("Arguments to memoized function must be hashable") + except TypeError as err: + msg = 'Arguments to memoized function must be hashable' + raise TypeError(msg) from err except KeyError: cache[k] = result = func(*args, **kwargs) return result - try: + with contextlib.suppress(AttributeError): memof.__name__ = func.__name__ - except AttributeError: - pass memof.__doc__ = func.__doc__ - memof.__wrapped__ = func + memof.__wrapped__ = func # type: ignore[attr-defined] return memof -class Compose(object): +# TODO: requires a mypy plugin to type check a function chain +# see `returns.pipeline.flow` +class Compose: """ A composition of functions See Also: compose """ - __slots__ = 'first', 'funcs' - def __init__(self, funcs): + __slots__ = ['first', 'funcs'] + + def __init__(self, funcs: Sequence[Callable]) -> None: funcs = tuple(reversed(funcs)) - self.first = funcs[0] - self.funcs = funcs[1:] + self.first: Callable = funcs[0] + self.funcs: tuple[Callable, ...] 
= funcs[1:] - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: ret = self.first(*args, **kwargs) for f in self.funcs: ret = f(ret) return ret - def __getstate__(self): + def __getstate__(self) -> tuple[Callable, tuple[Callable, ...]]: return self.first, self.funcs - def __setstate__(self, state): + def __setstate__( + self, state: tuple[Callable, tuple[Callable, ...]] + ) -> None: self.first, self.funcs = state @instanceproperty(classval=__doc__) - def __doc__(self): - def composed_doc(*fs): + def __doc__(self) -> str: # type: ignore[override] + def composed_doc(*fs: Callable) -> str: """Generate a docstring for the composition of fs. """ if not fs: # Argument name for the docstring. return '*args, **kwargs' - return '{f}({g})'.format(f=fs[0].__name__, g=composed_doc(*fs[1:])) + f = fs[0].__name__ + g_or_args = composed_doc(*fs[1:]) + return f'{f}({g_or_args})' try: - return ( - 'lambda *args, **kwargs: ' + - composed_doc(*reversed((self.first,) + self.funcs)) - ) + body = composed_doc(*reversed((self.first, *self.funcs))) + return f'lambda *args, **kwargs: {body}' except AttributeError: # One of our callables does not have a `__name__`, whatever. 
return 'A composition of functions' @property - def __name__(self): + def __name__(self) -> str: try: return '_of_'.join( - (f.__name__ for f in reversed((self.first,) + self.funcs)) + f.__name__ for f in reversed((self.first, *self.funcs)) ) except AttributeError: - return type(self).__name__ + return 'Compose' - def __repr__(self): - return '{.__class__.__name__}{!r}'.format( - self, tuple(reversed((self.first, ) + self.funcs))) + def __repr__(self) -> str: + name = self.__class__.__name__ + funcs = tuple(reversed((self.first, *self.funcs))) + return f'{name}{funcs!r}' - def __eq__(self, other): + def __eq__(self, other: Any) -> bool | NotImplementedType: if isinstance(other, Compose): return other.first == self.first and other.funcs == self.funcs return NotImplemented - def __ne__(self, other): + def __ne__(self, other: Any) -> bool | NotImplementedType: equality = self.__eq__(other) return NotImplemented if equality is NotImplemented else not equality - def __hash__(self): + def __hash__(self) -> int: return hash(self.first) ^ hash(self.funcs) # Mimic the descriptor behavior of python functions. # i.e. let Compose be called as a method when bound to a class. # adapted from # docs.python.org/3/howto/descriptor.html#functions-and-methods - def __get__(self, obj, objtype=None): + def __get__(self, obj: object, objtype: type | None = None) -> Any: return self if obj is None else MethodType(self, obj) # introspection with Signature is only possible from py3.3+ @instanceproperty - def __signature__(self): + def __signature__(self) -> inspect.Signature: base = inspect.signature(self.first) last = inspect.signature(self.funcs[-1]) return base.replace(return_annotation=last.return_annotation) @@ -557,7 +723,7 @@ def __signature__(self): __wrapped__ = instanceproperty(attrgetter('first')) -def compose(*funcs): +def compose(*funcs: Callable) -> Callable | Compose: """ Compose functions to operate in series. Returns a function that applies other functions in sequence. 
@@ -579,11 +745,10 @@ def compose(*funcs): return identity if len(funcs) == 1: return funcs[0] - else: - return Compose(funcs) + return Compose(funcs) -def compose_left(*funcs): +def compose_left(*funcs: Callable) -> Callable | Compose: """ Compose functions to operate in series. Returns a function that applies other functions in sequence. @@ -604,7 +769,7 @@ def compose_left(*funcs): return compose(*reversed(funcs)) -def pipe(data, *funcs): +def pipe(data: Any, *funcs: Callable) -> Any: """ Pipe a value through a sequence of functions I.e. ``pipe(data, f, g, h)`` is equivalent to ``h(g(f(data)))`` @@ -629,7 +794,7 @@ def pipe(data, *funcs): return data -def complement(func): +def complement(func: Callable[[Any], bool]) -> Compose: """ Convert a predicate function to its logical complement. In other words, return a function that, for inputs that normally @@ -642,10 +807,10 @@ def complement(func): >>> isodd(2) False """ - return compose(not_, func) + return cast(Compose, compose(not_, func)) -class juxt(object): +class juxt(Generic[_P, _T]): """ Creates a function that calls several functions with the same arguments Takes several functions and returns a function that applies its arguments @@ -661,24 +826,25 @@ class juxt(object): >>> juxt([inc, double])(10) (11, 20) """ + __slots__ = ['funcs'] - def __init__(self, *funcs): + def __init__(self, *funcs: Callable[_P, _T]) -> None: if len(funcs) == 1 and not callable(funcs[0]): funcs = funcs[0] - self.funcs = tuple(funcs) + self.funcs: tuple[Callable[_P, _T], ...] 
= tuple(funcs) - def __call__(self, *args, **kwargs): + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> tuple[_T, ...]: return tuple(func(*args, **kwargs) for func in self.funcs) - def __getstate__(self): + def __getstate__(self) -> tuple[Callable[_P, _T], ...]: return self.funcs - def __setstate__(self, state): + def __setstate__(self, state: tuple[Callable[_P, _T], ...]) -> None: self.funcs = state -def do(func, x): +def do(func: Callable[[_T], Any], x: _T) -> _T: """ Runs ``func`` on ``x``, returns ``x`` Because the results of ``func`` are not returned, only the side @@ -705,7 +871,7 @@ def do(func, x): @curry -def flip(func, a, b): +def flip(func: Callable[[_S, _T], _U], a: _T, b: _S) -> _U: """ Call the function call with the arguments flipped This function is curried. @@ -731,13 +897,13 @@ def flip(func, a, b): return func(b, a) -def return_none(exc): +def return_none(exc: Exception) -> Literal[None]: # noqa: ARG001 """ Returns None. """ return None -class excepts(object): +class excepts(Generic[_P, _T]): """A wrapper around a function to catch exceptions and dispatch to a handler. @@ -766,27 +932,36 @@ class excepts(object): >>> excepting({0: 1}) 1 """ - def __init__(self, exc, func, handler=return_none): + + exc: type[Exception] | tuple[type[Exception], ...] 
+ func: Callable[_P, _T] + handler: Callable[[Exception], _T | None] + + def __init__( + self, + exc: type[Exception] | tuple[type[Exception], ...], + func: Callable[_P, _T], + handler: Callable[[Exception], _T | None] | None = None, + ) -> None: self.exc = exc self.func = func - self.handler = handler + self.handler = handler or return_none - def __call__(self, *args, **kwargs): + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T | None: try: return self.func(*args, **kwargs) except self.exc as e: return self.handler(e) @instanceproperty(classval=__doc__) - def __doc__(self): + def __doc__(self) -> str: # type: ignore[override] from textwrap import dedent exc = self.exc try: if isinstance(exc, tuple): - exc_name = '(%s)' % ', '.join( - map(attrgetter('__name__'), exc), - ) + names = ', '.join(map(attrgetter('__name__'), exc)) + exc_name = f'({names})' else: exc_name = exc.__name__ @@ -807,53 +982,67 @@ def __doc__(self): exc=exc_name, ) except AttributeError: - return type(self).__doc__ + return str(type(self).__doc__) @property - def __name__(self): + def __name__(self) -> str: exc = self.exc try: if isinstance(exc, tuple): exc_name = '_or_'.join(map(attrgetter('__name__'), exc)) else: exc_name = exc.__name__ - return '%s_excepting_%s' % (self.func.__name__, exc_name) + return f'{self.func.__name__}_excepting_{exc_name}' except AttributeError: return 'excepting' -def _check_sigspec(sigspec, func, builtin_func, *builtin_args): +def _has_signature_get(func: Callable) -> bool: + if func not in _sigs.signatures: + return False + if not hasattr(func, '__signature__') or not hasattr( + func.__signature__, '__get__' + ): + return False + return True + + +def _check_sigspec_orig( + sigspec: inspect.Signature | None, + func: Callable, + builtin_func: Callable[..., _S], + *builtin_args: Any, +) -> tuple[None, _S | bool | None] | tuple[inspect.Signature, None]: if sigspec is None: try: sigspec = inspect.signature(func) except (ValueError, TypeError) as e: - 
sigspec = e - if isinstance(sigspec, ValueError): - return None, builtin_func(*builtin_args) - elif not isinstance(sigspec, inspect.Signature): - if ( - func in _sigs.signatures - and (( - hasattr(func, '__signature__') - and hasattr(func.__signature__, '__get__') - )) - ): - val = builtin_func(*builtin_args) - return None, val - return None, False + if isinstance(e, ValueError): + return None, builtin_func(*builtin_args) + if _has_signature_get(func): + val = builtin_func(*builtin_args) + return None, val + return None, False return sigspec, None if PYPY: # pragma: no cover - _check_sigspec_orig = _check_sigspec - def _check_sigspec(sigspec, func, builtin_func, *builtin_args): + def _check_sigspec( + sigspec: inspect.Signature | None, + func: Callable, + builtin_func: Callable[..., _S], + *builtin_args: Any, + ) -> tuple[None, _S | bool | None] | tuple[inspect.Signature, None]: # PyPy may lie, so use our registry for builtins instead if func in _sigs.signatures: val = builtin_func(*builtin_args) return None, val return _check_sigspec_orig(sigspec, func, builtin_func, *builtin_args) +else: + _check_sigspec = _check_sigspec_orig + _check_sigspec.__doc__ = """ \ Private function to aid in introspection compatibly across Python versions. 
@@ -863,36 +1052,53 @@ def _check_sigspec(sigspec, func, builtin_func, *builtin_args): """ -def num_required_args(func, sigspec=None): - sigspec, rv = _check_sigspec(sigspec, func, _sigs._num_required_args, - func) +def num_required_args( + func: Callable, + sigspec: inspect.Signature | None = None, +) -> int | bool | None: + sigspec, rv = _check_sigspec(sigspec, func, _sigs._num_required_args, func) if sigspec is None: return rv - return sum(1 for p in sigspec.parameters.values() - if p.default is p.empty - and p.kind in (p.POSITIONAL_OR_KEYWORD, p.POSITIONAL_ONLY)) - - -def has_varargs(func, sigspec=None): + return sum( + 1 + for p in sigspec.parameters.values() + if p.default is p.empty + and p.kind in (p.POSITIONAL_OR_KEYWORD, p.POSITIONAL_ONLY) + ) + + +def has_varargs( + func: Callable, + sigspec: inspect.Signature | None = None, +) -> bool | None: sigspec, rv = _check_sigspec(sigspec, func, _sigs._has_varargs, func) if sigspec is None: return rv - return any(p.kind == p.VAR_POSITIONAL - for p in sigspec.parameters.values()) + return any(p.kind == p.VAR_POSITIONAL for p in sigspec.parameters.values()) -def has_keywords(func, sigspec=None): +def has_keywords( + func: Callable, + sigspec: inspect.Signature | None = None, +) -> bool | None: sigspec, rv = _check_sigspec(sigspec, func, _sigs._has_keywords, func) if sigspec is None: return rv - return any(p.default is not p.empty - or p.kind in (p.KEYWORD_ONLY, p.VAR_KEYWORD) - for p in sigspec.parameters.values()) - - -def is_valid_args(func, args, kwargs, sigspec=None): - sigspec, rv = _check_sigspec(sigspec, func, _sigs._is_valid_args, - func, args, kwargs) + return any( + p.default is not p.empty or p.kind in (p.KEYWORD_ONLY, p.VAR_KEYWORD) + for p in sigspec.parameters.values() + ) + + +def is_valid_args( + func: Callable, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + sigspec: inspect.Signature | None = None, +) -> bool | None: + sigspec, rv = _check_sigspec( + sigspec, func, _sigs._is_valid_args, 
func, args, kwargs + ) if sigspec is None: return rv try: @@ -902,9 +1108,15 @@ def is_valid_args(func, args, kwargs, sigspec=None): return True -def is_partial_args(func, args, kwargs, sigspec=None): - sigspec, rv = _check_sigspec(sigspec, func, _sigs._is_partial_args, - func, args, kwargs) +def is_partial_args( + func: Callable, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + sigspec: inspect.Signature | None = None, +) -> bool | None: + sigspec, rv = _check_sigspec( + sigspec, func, _sigs._is_partial_args, func, args, kwargs + ) if sigspec is None: return rv try: @@ -914,7 +1126,11 @@ def is_partial_args(func, args, kwargs, sigspec=None): return True -def is_arity(n, func, sigspec=None): +def is_arity( + n: int, + func: Callable, + sigspec: inspect.Signature | None = None, +) -> bool | None: """ Does a function have only n positional arguments? This function relies on introspection and does not call the function. @@ -1046,4 +1262,4 @@ def is_arity(n, func, sigspec=None): Many builtins in the standard library are also supported. """ -from . import _signatures as _sigs +from . 
import _signatures as _sigs # noqa: E402 diff --git a/toolz/itertoolz.py b/toolz/itertoolz.py index 634f4869..aacc48af 100644 --- a/toolz/itertoolz.py +++ b/toolz/itertoolz.py @@ -1,11 +1,53 @@ -import itertools -import heapq +from __future__ import annotations + import collections +import heapq +import itertools import operator +from collections.abc import ( + Callable, + Generator, + Iterator, + Mapping, + Sequence, + Sized, +) from functools import partial from itertools import filterfalse, zip_longest -from collections.abc import Sequence -from toolz.utils import no_default +from random import Random +from typing import TYPE_CHECKING, Any, Iterable, TypeVar, cast, overload + +from toolz.utils import no_default, no_pad + +_T = TypeVar('_T') +_S = TypeVar('_S') +_U = TypeVar('_U') + +if TYPE_CHECKING: + from abc import abstractmethod + from typing import Protocol + + from typing_extensions import TypeGuard # >= py310 + + from toolz.utils import NoDefaultType, NoPadType + + class Comparable(Protocol): + """Protocol for annotating comparable types.""" + + @abstractmethod + def __lt__(self: _CT, other: _CT) -> bool: + pass + + class Randomable(Protocol): + def random(self) -> float: ... 
+ + + _CT = TypeVar('_CT', bound=Comparable) + Predicate = Callable[[_T], object] + BinaryOp = Callable[[_T, _T], _T] + UnaryOp = Callable[[_T], _T] + TransformOp = Callable[[_T], _S] + SeqOrMapping = Sequence[_T] | Mapping[Any, _T] __all__ = ('remove', 'accumulate', 'groupby', 'merge_sorted', 'interleave', @@ -16,7 +58,13 @@ 'join', 'tail', 'diff', 'topk', 'peek', 'peekn', 'random_sample') -def remove(predicate, seq): +def _is_no_default(x: Any) -> TypeGuard[NoDefaultType]: + if x is no_default: + return True + return False + + +def remove(predicate: Predicate[_T] | None, seq: Iterable[_T]) -> Iterator[_T]: """ Return those items of sequence for which predicate(item) is False >>> def iseven(x): @@ -27,7 +75,11 @@ def remove(predicate, seq): return filterfalse(predicate, seq) -def accumulate(binop, seq, initial=no_default): +def accumulate( + binop: BinaryOp[_T], + seq: Iterable[_T], + initial: _T | NoDefaultType = no_default, +) -> Generator[_T, None, None]: """ Repeatedly apply binary function to a sequence, accumulating results >>> from operator import add, mul @@ -55,20 +107,21 @@ def accumulate(binop, seq, initial=no_default): itertools.accumulate : In standard itertools for Python 3.2+ """ seq = iter(seq) - if initial == no_default: + if _is_no_default(initial): try: result = next(seq) except StopIteration: return else: - result = initial + result = cast(_T, initial) yield result for elem in seq: result = binop(result, elem) yield result -def groupby(key, seq): +# TODO: overload with key not callable +def groupby(key: TransformOp[_T, _S], seq: Iterable[_T]) -> dict[_S, list[_T]]: """ Group a collection by a key function >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] @@ -95,16 +148,17 @@ def groupby(key, seq): """ if not callable(key): key = getter(key) - d = collections.defaultdict(lambda: [].append) + d: dict[_S, list[_T]] = collections.defaultdict(list) for item in seq: - d[key(item)](item) - rv = {} - for k, v in d.items(): - rv[k] = 
v.__self__ - return rv + vals = d[key(item)] + vals.append(item) + return d -def merge_sorted(*seqs, **kwargs): +def merge_sorted( + *seqs: Iterable[_CT], + key: UnaryOp[_CT] | None = None, +) -> Iterator[_CT]: """ Merge and sort a collection of sorted collections This works lazily and only keeps one value from each iterable in memory. @@ -122,28 +176,20 @@ def merge_sorted(*seqs, **kwargs): """ if len(seqs) == 0: return iter([]) - elif len(seqs) == 1: + if len(seqs) == 1: return iter(seqs[0]) - key = kwargs.get('key', None) if key is None: return _merge_sorted_binary(seqs) - else: - return _merge_sorted_binary_key(seqs, key) + return _merge_sorted_binary_key(seqs, key) -def _merge_sorted_binary(seqs): +def _merge_sorted_binary(seqs: Sequence[Iterable[_CT]]) -> Iterator[_CT]: mid = len(seqs) // 2 L1 = seqs[:mid] - if len(L1) == 1: - seq1 = iter(L1[0]) - else: - seq1 = _merge_sorted_binary(L1) + seq1 = iter(L1[0]) if len(L1) == 1 else _merge_sorted_binary(L1) L2 = seqs[mid:] - if len(L2) == 1: - seq2 = iter(L2[0]) - else: - seq2 = _merge_sorted_binary(L2) + seq2 = iter(L2[0]) if len(L2) == 1 else _merge_sorted_binary(L2) try: val2 = next(seq2) @@ -175,18 +221,15 @@ def _merge_sorted_binary(seqs): yield val1 -def _merge_sorted_binary_key(seqs, key): +def _merge_sorted_binary_key( + seqs: Sequence[Iterable[_CT]], + key: UnaryOp[_CT], +) -> Iterator[_CT]: mid = len(seqs) // 2 L1 = seqs[:mid] - if len(L1) == 1: - seq1 = iter(L1[0]) - else: - seq1 = _merge_sorted_binary_key(L1, key) + seq1 = iter(L1[0]) if len(L1) == 1 else _merge_sorted_binary_key(L1, key) L2 = seqs[mid:] - if len(L2) == 1: - seq2 = iter(L2[0]) - else: - seq2 = _merge_sorted_binary_key(L2, key) + seq2 = iter(L2[0]) if len(L2) == 1 else _merge_sorted_binary_key(L2, key) try: val2 = next(seq2) @@ -221,7 +264,7 @@ def _merge_sorted_binary_key(seqs, key): yield val1 -def interleave(seqs): +def interleave(seqs: Sequence[Sequence[_T]]) -> Iterator[_T]: """ Interleave a sequence of sequences >>> 
list(interleave([[1, 2], [3, 4]])) @@ -234,7 +277,7 @@ def interleave(seqs): Returns a lazy iterator """ - iters = itertools.cycle(map(iter, seqs)) + iters: Iterator[Iterator[_T]] = itertools.cycle(map(iter, seqs)) while True: try: for itr in iters: @@ -245,7 +288,10 @@ def interleave(seqs): iters = itertools.cycle(itertools.takewhile(predicate, iters)) -def unique(seq, key=None): +def unique( + seq: Sequence[_T], + key: TransformOp[_T, Any] | None = None, +) -> Iterator[_T]: """ Return only unique elements of a sequence >>> tuple(unique((1, 2, 3))) @@ -259,21 +305,20 @@ def unique(seq, key=None): ('cat', 'mouse') """ seen = set() - seen_add = seen.add if key is None: for item in seq: if item not in seen: - seen_add(item) + seen.add(item) yield item else: # calculate key for item in seq: val = key(item) if val not in seen: - seen_add(val) + seen.add(val) yield item -def isiterable(x): +def isiterable(x: Any) -> TypeGuard[Iterable]: """ Is x iterable? >>> isiterable([1, 2, 3]) @@ -290,7 +335,7 @@ def isiterable(x): return False -def isdistinct(seq): +def isdistinct(seq: Iterable | Sequence) -> bool: """ All values in sequence are distinct >>> isdistinct([1, 2, 3]) @@ -303,19 +348,18 @@ def isdistinct(seq): >>> isdistinct("World") True """ - if iter(seq) is seq: - seen = set() - seen_add = seen.add - for item in seq: - if item in seen: - return False - seen_add(item) - return True - else: + if isinstance(seq, Sequence): return len(seq) == len(set(seq)) + seen = set() + for item in seq: + if item in seen: + return False + seen.add(item) + return True + -def take(n, seq): +def take(n: int, seq: Iterable[_T]) -> Iterator[_T]: """ The first n elements of a sequence >>> list(take(2, [10, 20, 30, 40, 50])) @@ -328,7 +372,7 @@ def take(n, seq): return itertools.islice(seq, n) -def tail(n, seq): +def tail(n: int, seq: Sequence[_T]) -> Sequence[_T]: """ The last n elements of a sequence >>> tail(2, [10, 20, 30, 40, 50]) @@ -344,7 +388,7 @@ def tail(n, seq): return 
tuple(collections.deque(seq, n)) -def drop(n, seq): +def drop(n: int, seq: Sequence[_T]) -> Iterator[_T]: """ The sequence following the first n elements >>> list(drop(2, [10, 20, 30, 40, 50])) @@ -357,7 +401,7 @@ def drop(n, seq): return itertools.islice(seq, n, None) -def take_nth(n, seq): +def take_nth(n: int, seq: Sequence[_T]) -> Iterator[_T]: """ Every nth item in seq >>> list(take_nth(2, [10, 20, 30, 40, 50])) @@ -366,7 +410,7 @@ def take_nth(n, seq): return itertools.islice(seq, 0, None, n) -def first(seq): +def first(seq: Iterable[_T]) -> _T: """ The first element in a sequence >>> first('ABC') @@ -375,7 +419,7 @@ def first(seq): return next(iter(seq)) -def second(seq): +def second(seq: Iterable[_T]) -> _T: """ The second element in a sequence >>> second('ABC') @@ -386,19 +430,18 @@ def second(seq): return next(seq) -def nth(n, seq): +def nth(n: int, seq: Iterable[_T] | Sequence[_T]) -> _T: """ The nth element in a sequence >>> nth(1, 'ABC') 'B' """ - if isinstance(seq, (tuple, list, Sequence)): + if isinstance(seq, Sequence): return seq[n] - else: - return next(itertools.islice(seq, n, None)) + return next(itertools.islice(seq, n, None)) -def last(seq): +def last(seq: Sequence[_T]) -> _T: """ The last element in a sequence >>> last('ABC') @@ -410,14 +453,32 @@ def last(seq): rest = partial(drop, 1) -def _get(ind, seq, default): +def _get(ind: Any, seq: SeqOrMapping[_T], default: _T) -> _T: try: return seq[ind] except (KeyError, IndexError): return default -def get(ind, seq, default=no_default): +@overload +def get( # type: ignore[overload-overlap] + ind: Sequence[Any], + seq: SeqOrMapping[_T], + default: _T | NoDefaultType = ..., +) -> tuple[_T, ...]: ... + +@overload +def get( + ind: Any, + seq: SeqOrMapping[_T], + default: _T | NoDefaultType = ..., +) -> _T: ... 
+ +def get( + ind: Any | Sequence[Any], + seq: SeqOrMapping[_T], + default: _T | NoDefaultType = no_default, +) -> _T | tuple[_T, ...]: """ Get element in a sequence or dict Provides standard indexing @@ -451,30 +512,27 @@ def get(ind, seq, default=no_default): pluck """ try: - return seq[ind] + return seq[ind] # type: ignore[index] except TypeError: # `ind` may be a list if isinstance(ind, list): - if default == no_default: + if _is_no_default(default): if len(ind) > 1: - return operator.itemgetter(*ind)(seq) - elif ind: - return seq[ind[0]], - else: - return () - else: - return tuple(_get(i, seq, default) for i in ind) - elif default != no_default: - return default - else: - raise + return tuple(operator.itemgetter(*ind)(seq)) + if ind: + return (seq[ind[0]],) + return () + return tuple(_get(i, seq, cast(_T, default)) for i in ind) + + if not _is_no_default(default): + return cast(_T, default) + raise except (KeyError, IndexError): # we know `ind` is not a list - if default == no_default: - raise - else: - return default + if not _is_no_default(default): + return cast(_T, default) + raise -def concat(seqs): +def concat(seqs: Iterable[Iterable[_T]]) -> Iterator[_T]: """ Concatenate zero or more iterables, any of which may be infinite. An infinite sequence will prevent the rest of the arguments from @@ -492,7 +550,7 @@ def concat(seqs): return itertools.chain.from_iterable(seqs) -def concatv(*seqs): +def concatv(*seqs: Iterable[_T]) -> Iterator[_T]: """ Variadic version of concat >>> list(concatv([], ["a"], ["b", "c"])) @@ -504,7 +562,10 @@ def concatv(*seqs): return concat(seqs) -def mapcat(func, seqs): +def mapcat( + func: TransformOp[Iterable[_T], Iterable[_S]], + seqs: Iterable[Iterable[_T]], +) -> Iterator[_S]: """ Apply func to each sequence in seqs, concatenating results. 
>>> list(mapcat(lambda s: [c.upper() for c in s], @@ -514,7 +575,7 @@ def mapcat(func, seqs): return concat(map(func, seqs)) -def cons(el, seq): +def cons(el: _T, seq: Iterable[_T]) -> Iterator[_T]: """ Add el to beginning of (possibly infinite) sequence seq. >>> list(cons(1, [2, 3])) @@ -523,7 +584,7 @@ def cons(el, seq): return itertools.chain([el], seq) -def interpose(el, seq): +def interpose(el: _T, seq: Iterable[_T]) -> Iterator[_T]: """ Introduce element between each pair of elements in seq >>> list(interpose("a", [1, 2, 3])) @@ -534,7 +595,7 @@ def interpose(el, seq): return inposed -def frequencies(seq): +def frequencies(seq: Iterable[_T]) -> dict[_T, int]: """ Find number of occurrences of each value in seq >>> frequencies(['cat', 'cat', 'ox', 'pig', 'pig', 'cat']) #doctest: +SKIP @@ -544,13 +605,19 @@ def frequencies(seq): countby groupby """ - d = collections.defaultdict(int) + d: dict[_T, int] = collections.defaultdict(int) for item in seq: d[item] += 1 - return dict(d) + return d -def reduceby(key, binop, seq, init=no_default): +# TODO: overload with key not callable +def reduceby( + key: TransformOp[_T, _S], + binop: BinaryOp[_T], + seq: Iterable[_T], + init: _T | Callable[[], _T] | NoDefaultType = no_default, +) -> dict[_S, _T]: """ Perform a simultaneous groupby and reduction The computation: @@ -611,26 +678,25 @@ def reduceby(key, binop, seq, init=no_default): {True: set([2, 4]), False: set([1, 3])} """ - is_no_default = init == no_default - if not is_no_default and not callable(init): - _init = init - init = lambda: _init if not callable(key): key = getter(key) + d = {} for item in seq: k = key(item) if k not in d: - if is_no_default: + if _is_no_default(init): d[k] = item continue - else: + if callable(init): d[k] = init() + else: + d[k] = cast(_T, init) d[k] = binop(d[k], item) return d -def iterate(func, x): +def iterate(func: UnaryOp[_T], x: _T) -> Iterator[_T]: """ Repeatedly apply a function func onto an original input Yields x, then 
func(x), then func(func(x)), then func(func(func(x))), etc.. @@ -660,7 +726,7 @@ def iterate(func, x): x = func(x) -def sliding_window(n, seq): +def sliding_window(n: int, seq: Iterable[_T]) -> Iterator[tuple[_T, ...]]: """ A sequence of overlapping subsequences >>> list(sliding_window(2, [1, 2, 3, 4])) @@ -673,14 +739,19 @@ def sliding_window(n, seq): >>> list(map(mean, sliding_window(2, [1, 2, 3, 4]))) [1.5, 2.5, 3.5] """ - return zip(*(collections.deque(itertools.islice(it, i), 0) or it - for i, it in enumerate(itertools.tee(seq, n)))) - - -no_pad = '__no__pad__' - - -def partition(n, seq, pad=no_pad): + return zip( + *( + collections.deque(itertools.islice(it, i), 0) or it + for i, it in enumerate(itertools.tee(seq, n)) + ) + ) + + +def partition( + n: int, + seq: Iterable[_T], + pad: _S | NoPadType = no_pad, +) -> Iterator[tuple[_T | _S, ...]]: """ Partition sequence into tuples of length n >>> list(partition(2, [1, 2, 3, 4])) @@ -701,11 +772,10 @@ def partition(n, seq, pad=no_pad): args = [iter(seq)] * n if pad is no_pad: return zip(*args) - else: - return zip_longest(*args, fillvalue=pad) + return zip_longest(*args, fillvalue=pad) -def partition_all(n, seq): +def partition_all(n: int, seq: Sequence[_T]) -> Iterator[tuple[_T, ...]]: """ Partition all elements of sequence into tuples of length at most n The final tuple may be shorter to accommodate extra elements. 
@@ -719,6 +789,12 @@ def partition_all(n, seq): See Also: partition """ + + def cast_out(val: tuple) -> tuple[_T, ...]: + # Trick for type-checkers, `prev` type can contain `no_pad` + # so cast to a type without `no_pad` + return cast('tuple[_T, ...]', val) + args = [iter(seq)] * n it = zip_longest(*args, fillvalue=no_pad) try: @@ -726,13 +802,13 @@ def partition_all(n, seq): except StopIteration: return for item in it: - yield prev + yield cast_out(prev) prev = item if prev[-1] is no_pad: try: # If seq defines __len__, then # we can quickly calculate where no_pad starts - yield prev[:len(seq) % n] + yield cast_out(prev[: len(seq) % n]) except TypeError: # Get first index of no_pad without using .index() # https://github.com/pytoolz/toolz/issues/387 @@ -745,12 +821,12 @@ def partition_all(n, seq): hi = mid else: lo = mid + 1 - yield prev[:lo] + yield cast_out(prev[:lo]) else: - yield prev + yield cast_out(prev) -def count(seq): +def count(seq: Iterable) -> int: """ Count the number of items in seq Like the builtin ``len`` but works on lazy sequences. @@ -760,12 +836,32 @@ def count(seq): See also: len """ - if hasattr(seq, '__len__'): + if isinstance(seq, Sized): return len(seq) - return sum(1 for i in seq) + return sum(1 for _ in seq) + + +@overload +def pluck( # type: ignore[overload-overlap] + index: Sequence[Any], + seqs: Iterable[SeqOrMapping[_T]], + default: _T | NoDefaultType = ..., +) -> Iterator[tuple[_T, ...]]: ... -def pluck(ind, seqs, default=no_default): +@overload +def pluck( + index: Any, + seqs: Iterable[SeqOrMapping[_T]], + default: _T | NoDefaultType = ..., +) -> Iterator[_T]: ... + + +def pluck( # type: ignore[misc] + ind: Any | Sequence[Any], + seqs: Iterable[SeqOrMapping[_T]], + default: _T | NoDefaultType = no_default, +) -> Iterator[_T] | Iterator[tuple[_T, ...]]: """ plucks an element or several elements from each item in a sequence. 
``pluck`` maps ``itertoolz.get`` over a sequence and returns one or more @@ -788,31 +884,51 @@ def pluck(ind, seqs, default=no_default): get map """ - if default == no_default: + if _is_no_default(default): get = getter(ind) return map(get, seqs) - elif isinstance(ind, list): - return (tuple(_get(item, seq, default) for item in ind) - for seq in seqs) - return (_get(ind, seq, default) for seq in seqs) + if isinstance(ind, list): + return ( + tuple(_get(item, seq, cast(_T, default)) for item in ind) + for seq in seqs + ) + return (_get(ind, seq, cast(_T, default)) for seq in seqs) + + +@overload +def getter( # type: ignore[overload-overlap] + index: list[Any], +) -> Callable[[SeqOrMapping[_T]], tuple[_T, ...]]: ... -def getter(index): + +@overload +def getter(index: Any) -> Callable[[Sequence[_T] | Mapping[Any, _T]], _T]: ... + + +def getter( + index: Any | list[Any], +) -> Callable[[SeqOrMapping[_T]], _T | tuple[_T, ...]]: if isinstance(index, list): if len(index) == 1: index = index[0] return lambda x: (x[index],) - elif index: + if index: return operator.itemgetter(*index) - else: - return lambda x: () - else: - return operator.itemgetter(index) + return lambda _: () + return operator.itemgetter(index) -def join(leftkey, leftseq, rightkey, rightseq, - left_default=no_default, right_default=no_default): - """ Join two sequences on common attributes +# TODO: overload with leftkey/rightkey not callable +def join( + leftkey: TransformOp[_T, Any], + leftseq: Iterable[_T], + rightkey: TransformOp[_T, Any], + rightseq: Iterable[_T], + left_default: _U | NoDefaultType = no_default, + right_default: _U | NoDefaultType = no_default, +) -> Iterator[tuple[_T | _U, _T | _U]]: + """Join two sequences on common attributes This is a semi-streaming operation. The LEFT sequence is fully evaluated and placed into memory. 
The RIGHT sequence is evaluated lazily and so can @@ -876,14 +992,15 @@ def join(leftkey, leftseq, rightkey, rightseq, d = groupby(leftkey, leftseq) - if left_default == no_default and right_default == no_default: + if _is_no_default(left_default) and _is_no_default(right_default): # Inner Join for item in rightseq: key = rightkey(item) if key in d: for left_match in d[key]: - yield (left_match, item) - elif left_default != no_default and right_default == no_default: + ret = (left_match, item) + yield cast("tuple[_T | _U, _T | _U]", ret) + elif not _is_no_default(left_default) and _is_no_default(right_default): # Right Join for item in rightseq: key = rightkey(item) @@ -891,16 +1008,15 @@ def join(leftkey, leftseq, rightkey, rightseq, for left_match in d[key]: yield (left_match, item) else: - yield (left_default, item) - elif right_default != no_default: + yield (cast(_U, left_default), item) + elif not _is_no_default(right_default): seen_keys = set() - seen = seen_keys.add - if left_default == no_default: + if _is_no_default(left_default): # Left Join for item in rightseq: key = rightkey(item) - seen(key) + seen_keys.add(key) if key in d: for left_match in d[key]: yield (left_match, item) @@ -908,20 +1024,24 @@ def join(leftkey, leftseq, rightkey, rightseq, # Full Join for item in rightseq: key = rightkey(item) - seen(key) + seen_keys.add(key) if key in d: for left_match in d[key]: yield (left_match, item) else: - yield (left_default, item) + yield (cast(_U, left_default), item) for key, matches in d.items(): if key not in seen_keys: for match in matches: - yield (match, right_default) + yield (match, cast(_U, right_default)) -def diff(*seqs, **kwargs): +def diff( + *seqs: Iterable[_T], + default: _T | NoDefaultType = no_default, + key: TransformOp[_T, Any] | None = None, +) -> Iterator[tuple[_T, ...]]: """ Return those items that differ between sequences >>> list(diff([1, 2, 3], [1, 2, 10, 100])) @@ -940,16 +1060,20 @@ def diff(*seqs, **kwargs): """ N = len(seqs) 
if N == 1 and isinstance(seqs[0], list): - seqs = seqs[0] - N = len(seqs) + all_seqs: Iterable[Iterable[_T]] = seqs[0] + N = len(list(all_seqs)) + else: + all_seqs = cast(Iterable[Iterable[_T]], seqs) if N < 2: raise TypeError('Too few sequences given (min 2 required)') - default = kwargs.get('default', no_default) - if default == no_default: - iters = zip(*seqs) + + if not _is_no_default(default): + iters = cast( + "Iterator[tuple[_T, ...]]", + zip_longest(*all_seqs, fillvalue=default), + ) else: - iters = zip_longest(*seqs, fillvalue=default) - key = kwargs.get('key', None) + iters = zip(*all_seqs) if key is None: for items in iters: if items.count(items[0]) != N: @@ -961,7 +1085,11 @@ def diff(*seqs, **kwargs): yield items -def topk(k, seq, key=None): +def topk( + k: int, + seq: Iterable[_T], + key: Predicate[_T] | Any | None = None, +) -> tuple[_T, ...]: """ Find the k largest elements of a sequence Operates lazily in ``n*log(k)`` time @@ -982,7 +1110,7 @@ def topk(k, seq, key=None): return tuple(heapq.nlargest(k, seq, key=key)) -def peek(seq): +def peek(seq: Iterable[_T]) -> tuple[_T, Iterator[_T]]: """ Retrieve the next element of a sequence Returns the first element and an iterable equivalent to the original @@ -1000,7 +1128,7 @@ def peek(seq): return item, itertools.chain((item,), iterator) -def peekn(n, seq): +def peekn(n: int, seq: Iterable[_T]) -> tuple[tuple[_T, ...], Iterator[_T]]: """ Retrieve the next n elements of a sequence Returns a tuple of the first n elements and an iterable equivalent @@ -1018,7 +1146,17 @@ def peekn(n, seq): return peeked, itertools.chain(iter(peeked), iterator) -def random_sample(prob, seq, random_state=None): +def _has_random(random_state: Any) -> TypeGuard[Randomable]: + if hasattr(random_state, 'random'): + return True + return False + + +def random_sample( + prob: float, + seq: Iterable[_T], + random_state: int | Randomable | None = None, +) -> Iterator[_T]: """ Return elements from a sequence with probability of prob 
Returns a lazy iterator of random items from seq. @@ -1050,8 +1188,6 @@ def random_sample(prob, seq, random_state=None): >>> list(random_sample(0.1, seq, random_state=randobj)) [7, 9, 19, 25, 30, 32, 34, 48, 59, 60, 81, 98] """ - if not hasattr(random_state, 'random'): - from random import Random - - random_state = Random(random_state) + if not _has_random(random_state): + random_state = Random(random_state) # noqa: S311 return filter(lambda _: random_state.random() < prob, seq) diff --git a/toolz/py.typed b/toolz/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/toolz/recipes.py b/toolz/recipes.py index 89de88db..84d8a5db 100644 --- a/toolz/recipes.py +++ b/toolz/recipes.py @@ -1,11 +1,23 @@ +from __future__ import annotations + import itertools -from .itertoolz import frequencies, pluck, getter +from typing import TYPE_CHECKING, Any + +from .itertoolz import frequencies, getter, pluck + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator + from typing import TypeVar + + from .itertoolz import TransformOp + + _T = TypeVar('_T') __all__ = ('countby', 'partitionby') -def countby(key, seq): +def countby(key: Any, seq: Iterable[_T]) -> dict[_T, int]: """ Count elements of a collection by a key function >>> countby(len, ['cat', 'mouse', 'dog']) @@ -23,7 +35,9 @@ def countby(key, seq): return frequencies(map(key, seq)) -def partitionby(func, seq): +def partitionby( + func: TransformOp[_T, Any], seq: Iterable[_T] +) -> Iterator[tuple[_T, ...]]: """ Partition a sequence according to a function Partition `s` into a sequence of lists such that, when traversing diff --git a/toolz/sandbox/__init__.py b/toolz/sandbox/__init__.py index 0abda1cb..145c815c 100644 --- a/toolz/sandbox/__init__.py +++ b/toolz/sandbox/__init__.py @@ -1,2 +1,4 @@ from .core import EqualityHashKey, unzip from .parallel import fold + +__all__ = ['EqualityHashKey', 'fold', 'unzip'] diff --git a/toolz/sandbox/core.py b/toolz/sandbox/core.py index 55e09d74..5b03ac9f 100644 
--- a/toolz/sandbox/core.py +++ b/toolz/sandbox/core.py @@ -1,10 +1,17 @@ -from toolz.itertoolz import getter, cons, pluck -from itertools import tee, starmap +from __future__ import annotations + +from itertools import starmap, tee +from typing import Any, Callable, Generic, Iterable, TypeVar + +from toolz.itertoolz import cons, getter, pluck + +_S = TypeVar('_S') +_T = TypeVar('_T') # See #166: https://github.com/pytoolz/toolz/issues/166 # See #173: https://github.com/pytoolz/toolz/pull/173 -class EqualityHashKey(object): +class EqualityHashKey(Generic[_S, _T]): """ Create a hash key that uses equality comparisons between items. This may be used to create hash keys for otherwise unhashable types: @@ -58,44 +65,47 @@ class EqualityHashKey(object): See Also: identity """ + __slots__ = ['item', 'key'] - _default_hashkey = '__default__hashkey__' + _default_hashkey: str = '__default__hashkey__' - def __init__(self, key, item): + def __init__(self, key: _S, item: _T) -> None: if key is None: - self.key = self._default_hashkey + self.key: str | Callable = self._default_hashkey elif not callable(key): self.key = getter(key) else: self.key = key self.item = item - def __hash__(self): - if self.key == self._default_hashkey: - val = self.key + def __hash__(self) -> int: + if not callable(self.key): + val = self._default_hashkey else: val = self.key(self.item) return hash(val) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: try: - return (self._default_hashkey == other._default_hashkey and - self.item == other.item) + return bool( + self._default_hashkey == other._default_hashkey + and self.item == other.item + ) except AttributeError: return False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) - def __str__(self): + def __str__(self) -> str: return '=%s=' % str(self.item) - def __repr__(self): + def __repr__(self) -> str: return '=%s=' % repr(self.item) # See issue #293: 
https://github.com/pytoolz/toolz/issues/239 -def unzip(seq): +def unzip(seq: Iterable[Any]) -> tuple[Any, ...]: """Inverse of ``zip`` >>> a, b = unzip([('a', 1), ('b', 2)]) @@ -124,7 +134,7 @@ def unzip(seq): try: first = tuple(next(seq)) except StopIteration: - return tuple() + return () # and create them niters = len(first) diff --git a/toolz/sandbox/parallel.py b/toolz/sandbox/parallel.py index 114077d2..05ed7c13 100644 --- a/toolz/sandbox/parallel.py +++ b/toolz/sandbox/parallel.py @@ -1,16 +1,36 @@ +from __future__ import annotations + import functools +from typing import TYPE_CHECKING, Callable, Iterable, Sequence, TypeVar, cast + from toolz.itertoolz import partition_all from toolz.utils import no_default +if TYPE_CHECKING: + from toolz.itertoolz import BinaryOp + from toolz.utils import NoDefaultType + +_T = TypeVar('_T') -def _reduce(func, seq, initial=None): + +def _reduce( + func: BinaryOp[_T], + seq: Iterable[_T], + initial: _T | None = None, +) -> _T: if initial is None: return functools.reduce(func, seq) - else: - return functools.reduce(func, seq, initial) + return functools.reduce(func, seq, initial) -def fold(binop, seq, default=no_default, map=map, chunksize=128, combine=None): +def fold( + binop: BinaryOp[_T], + seq: Sequence[_T], + default: _T | NoDefaultType = no_default, + map: Callable = map, + chunksize: int = 128, + combine: BinaryOp[_T] | None = None, +) -> _T: """ Reduce without guarantee of ordered reduction. 
@@ -50,7 +70,8 @@ def fold(binop, seq, default=no_default, map=map, chunksize=128, combine=None): >>> fold(add, [1, 2, 3, 4], chunksize=2, map=map) 10 """ - assert chunksize > 1 + # assert chunksize > 1 + chunksize = max(chunksize, 2) if combine is None: combine = binop @@ -59,17 +80,17 @@ def fold(binop, seq, default=no_default, map=map, chunksize=128, combine=None): # Evaluate sequence in chunks via map if default == no_default: - results = map( - functools.partial(_reduce, binop), - chunks) + results = map(functools.partial(_reduce, binop), chunks) else: results = map( functools.partial(_reduce, binop, initial=default), - chunks) + chunks, + ) results = list(results) # TODO: Support complete laziness if len(results) == 1: # Return completed result - return results[0] + res = results[0] else: # Recurse to reaggregate intermediate results - return fold(combine, results, map=map, chunksize=chunksize) + res = fold(combine, results, map=map, chunksize=chunksize) + return cast(_T, res) diff --git a/toolz/sandbox/tests/test_core.py b/toolz/sandbox/tests/test_core.py index d2a5ed08..46f85832 100644 --- a/toolz/sandbox/tests/test_core.py +++ b/toolz/sandbox/tests/test_core.py @@ -1,7 +1,9 @@ -from toolz import curry, unique, first, take -from toolz.sandbox.core import EqualityHashKey, unzip from itertools import count, repeat +from toolz import curry, first, take, unique +from toolz.sandbox.core import EqualityHashKey, unzip + + def test_EqualityHashKey_default_key(): EqualityHashDefault = curry(EqualityHashKey, None) L1 = [1] @@ -31,8 +33,8 @@ def test_EqualityHashKey_default_key(): assert repr(E1) == '=[1]=' assert E1 != E2 assert not (E1 == E2) - assert E1 == EqualityHashDefault(L1) - assert not (E1 != EqualityHashDefault(L1)) + assert EqualityHashDefault(L1) == E1 + assert not (EqualityHashDefault(L1) != E1) assert E1 != L1 assert not (E1 == L1) diff --git a/toolz/sandbox/tests/test_parallel.py b/toolz/sandbox/tests/test_parallel.py index 7a455937..d654fe23 100644 
--- a/toolz/sandbox/tests/test_parallel.py +++ b/toolz/sandbox/tests/test_parallel.py @@ -1,12 +1,14 @@ -from toolz.sandbox.parallel import fold -from toolz import reduce +from multiprocessing import Pool from operator import add from pickle import dumps, loads -from multiprocessing import Pool +from toolz import reduce +from toolz.sandbox.parallel import fold + +from toolz.utils import no_default # is comparison will fail between this and no_default -no_default2 = loads(dumps('__no__default__')) +no_default2 = loads(dumps(no_default)) def test_fold(): diff --git a/toolz/tests/test_itertoolz.py b/toolz/tests/test_itertoolz.py index 27907b9e..c3043142 100644 --- a/toolz/tests/test_itertoolz.py +++ b/toolz/tests/test_itertoolz.py @@ -1,24 +1,27 @@ import itertools -from itertools import starmap -from toolz.utils import raises from functools import partial -from random import Random -from pickle import dumps, loads -from toolz.itertoolz import (remove, groupby, merge_sorted, - concat, concatv, interleave, unique, - isiterable, getter, - mapcat, isdistinct, first, second, - nth, take, tail, drop, interpose, get, - rest, last, cons, frequencies, - reduceby, iterate, accumulate, - sliding_window, count, partition, - partition_all, take_nth, pluck, join, - diff, topk, peek, peekn, random_sample) +from itertools import starmap from operator import add, mul +from pickle import dumps, loads +from random import Random + +from toolz.itertoolz import ( + remove, groupby, merge_sorted, + concat, concatv, interleave, unique, + isiterable, getter, + mapcat, isdistinct, first, second, + nth, take, tail, drop, interpose, get, + rest, last, cons, frequencies, + reduceby, iterate, accumulate, + sliding_window, count, partition, + partition_all, take_nth, pluck, join, + diff, topk, peek, peekn, random_sample +) +from toolz.utils import raises, no_default # is comparison will fail between this and no_default -no_default2 = loads(dumps('__no__default__')) +no_default2 = 
loads(dumps(no_default)) def identity(x): @@ -555,9 +558,9 @@ def test_random_sample(): assert list(random_sample(prob=1, seq=alist, random_state=2016)) == alist - mk_rsample = lambda rs=1: list(random_sample(prob=0.1, - seq=alist, - random_state=rs)) + def mk_rsample(rs=1): + return list(random_sample(prob=0.1, seq=alist, random_state=rs)) + rsample1 = mk_rsample() assert rsample1 == mk_rsample() @@ -569,6 +572,6 @@ def test_random_sample(): assert mk_rsample(hash(object)) == mk_rsample(hash(object)) assert mk_rsample(hash(object)) != mk_rsample(hash(object())) - assert mk_rsample(b"a") == mk_rsample(u"a") + assert mk_rsample(b"a") == mk_rsample("a") assert raises(TypeError, lambda: mk_rsample([])) diff --git a/toolz/tests/test_recipes.py b/toolz/tests/test_recipes.py index a45c054f..a7719f14 100644 --- a/toolz/tests/test_recipes.py +++ b/toolz/tests/test_recipes.py @@ -1,4 +1,4 @@ -from toolz import first, identity, countby, partitionby +from toolz import countby, first, identity, partitionby def iseven(x): diff --git a/toolz/tests/test_serialization.py b/toolz/tests/test_serialization.py index 5f432ae5..11665339 100644 --- a/toolz/tests/test_serialization.py +++ b/toolz/tests/test_serialization.py @@ -1,7 +1,8 @@ -from toolz import * +import pickle + import toolz import toolz.curried -import pickle +from toolz import complement, compose, curry, juxt from toolz.utils import raises @@ -66,7 +67,7 @@ def test_curried_exceptions(): @toolz.curry -class GlobalCurried(object): +class GlobalCurried: def __init__(self, x, y): self.x = x self.y = y @@ -83,7 +84,7 @@ def __reduce__(self): return GlobalCurried, (self.x, self.y) @toolz.curry - class NestedCurried(object): + class NestedCurried: def __init__(self, x, y): self.x = x self.y = y @@ -99,7 +100,7 @@ def __reduce__(self): """Allow us to serialize instances of NestedCurried""" return GlobalCurried.NestedCurried, (self.x, self.y) - class Nested(object): + class Nested: def __init__(self, x, y): self.x = x self.y = y 
@@ -185,8 +186,9 @@ def preserves_identity(obj): def test_curried_bad_qualname(): + @toolz.curry - class Bad(object): + class Bad: __qualname__ = 'toolz.functoolz.not.a.valid.path' assert raises(pickle.PicklingError, lambda: pickle.dumps(Bad)) diff --git a/toolz/tests/test_signatures.py b/toolz/tests/test_signatures.py index 03b9293c..04b1e998 100644 --- a/toolz/tests/test_signatures.py +++ b/toolz/tests/test_signatures.py @@ -1,13 +1,16 @@ import functools + import toolz._signatures as _sigs -from toolz._signatures import builtins, _is_valid_args, _is_partial_args +from toolz._signatures import _is_partial_args, _is_valid_args, builtins def test_is_valid(check_valid=_is_valid_args, incomplete=False): orig_check_valid = check_valid - check_valid = lambda func, *args, **kwargs: orig_check_valid(func, args, kwargs) + check_valid = ( # noqa: E731 + lambda func, *args, **kwargs: orig_check_valid(func, args, kwargs) + ) - assert check_valid(lambda x: None) is None + assert check_valid(lambda _: None) is None f = builtins.abs assert check_valid(f) is incomplete diff --git a/toolz/tests/test_utils.py b/toolz/tests/test_utils.py index e7a0eaba..2d751236 100644 --- a/toolz/tests/test_utils.py +++ b/toolz/tests/test_utils.py @@ -1,6 +1,6 @@ from toolz.utils import raises -def test_raises(): +def test_raises() -> None: assert raises(ZeroDivisionError, lambda: 1 / 0) assert not raises(ZeroDivisionError, lambda: 1) diff --git a/toolz/utils.py b/toolz/utils.py index 1002c464..a12a1376 100644 --- a/toolz/utils.py +++ b/toolz/utils.py @@ -1,4 +1,10 @@ -def raises(err, lamda): +from __future__ import annotations + +from enum import Enum +from typing import Any, Callable + + +def raises(err: type[Exception], lamda: Callable[[], Any]) -> bool: try: lamda() return False @@ -6,4 +12,21 @@ def raises(err, lamda): return True -no_default = '__no__default__' +class NoDefaultType(Enum): + no_default = '__no_default__' + + +no_default = NoDefaultType.no_default + +# no_default = 
'__no__default__' +# NoDefaultType = Literal['__no__default__'] + + +class NoPadType(Enum): + no_pad = '__no_pad__' + + +no_pad = NoPadType.no_pad + +# no_pad = '__no__pad__' +# NoPadType = Literal['__no__pad__'] diff --git a/versioneer.py b/versioneer.py deleted file mode 100644 index 509b221a..00000000 --- a/versioneer.py +++ /dev/null @@ -1,1822 +0,0 @@ - -# Version: 0.18 - -"""The Versioneer - like a rocketeer, but for versions. - -The Versioneer -============== - -* like a rocketeer, but for versions! -* https://github.com/warner/python-versioneer -* Brian Warner -* License: Public Domain -* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy -* [![Latest Version] -(https://pypip.in/version/versioneer/badge.svg?style=flat) -](https://pypi.python.org/pypi/versioneer/) -* [![Build Status] -(https://travis-ci.org/warner/python-versioneer.png?branch=master) -](https://travis-ci.org/warner/python-versioneer) - -This is a tool for managing a recorded version number in distutils-based -python projects. The goal is to remove the tedious and error-prone "update -the embedded version string" step from your release process. Making a new -release should be as easy as recording a new tag in your version-control -system, and maybe making new tarballs. 
- - -## Quick Install - -* `pip install versioneer` to somewhere to your $PATH -* add a `[versioneer]` section to your setup.cfg (see below) -* run `versioneer install` in your source tree, commit the results - -## Version Identifiers - -Source trees come from a variety of places: - -* a version-control system checkout (mostly used by developers) -* a nightly tarball, produced by build automation -* a snapshot tarball, produced by a web-based VCS browser, like github's - "tarball from tag" feature -* a release tarball, produced by "setup.py sdist", distributed through PyPI - -Within each source tree, the version identifier (either a string or a number, -this tool is format-agnostic) can come from a variety of places: - -* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows - about recent "tags" and an absolute revision-id -* the name of the directory into which the tarball was unpacked -* an expanded VCS keyword ($Id$, etc) -* a `_version.py` created by some earlier build step - -For released software, the version identifier is closely related to a VCS -tag. Some projects use tag names that include more than just the version -string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool -needs to strip the tag prefix to extract the version identifier. For -unreleased software (between tags), the version identifier should provide -enough information to help developers recreate the same tree, while also -giving them an idea of roughly how old the tree is (after version 1.2, before -version 1.3). Many VCS systems can report a description that captures this, -for example `git describe --tags --dirty --always` reports things like -"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the -0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes. 
- -The version identifier is used for multiple purposes: - -* to allow the module to self-identify its version: `myproject.__version__` -* to choose a name and prefix for a 'setup.py sdist' tarball - -## Theory of Operation - -Versioneer works by adding a special `_version.py` file into your source -tree, where your `__init__.py` can import it. This `_version.py` knows how to -dynamically ask the VCS tool for version information at import time. - -`_version.py` also contains `$Revision$` markers, and the installation -process marks `_version.py` to have this marker rewritten with a tag name -during the `git archive` command. As a result, generated tarballs will -contain enough information to get the proper version. - -To allow `setup.py` to compute a version too, a `versioneer.py` is added to -the top level of your source tree, next to `setup.py` and the `setup.cfg` -that configures it. This overrides several distutils/setuptools commands to -compute the version when invoked, and changes `setup.py build` and `setup.py -sdist` to replace `_version.py` with a small static file that contains just -the generated version data. - -## Installation - -See [INSTALL.md](./INSTALL.md) for detailed installation instructions. - -## Version-String Flavors - -Code which uses Versioneer can learn about its version string at runtime by -importing `_version` from your main `__init__.py` file and running the -`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can -import the top-level `versioneer.py` and run `get_versions()`. - -Both functions return a dictionary with different flavors of version -information: - -* `['version']`: A condensed version string, rendered using the selected - style. This is the most commonly used value for the project's version - string. The default "pep440" style yields strings like `0.11`, - `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section - below for alternative styles. 
- -* `['full-revisionid']`: detailed revision identifier. For Git, this is the - full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". - -* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the - commit date in ISO 8601 format. This will be None if the date is not - available. - -* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that - this is only accurate if run in a VCS checkout, otherwise it is likely to - be False or None - -* `['error']`: if the version string could not be computed, this will be set - to a string describing the problem, otherwise it will be None. It may be - useful to throw an exception in setup.py if this is set, to avoid e.g. - creating tarballs with a version string of "unknown". - -Some variants are more useful than others. Including `full-revisionid` in a -bug report should allow developers to reconstruct the exact code being tested -(or indicate the presence of local changes that should be shared with the -developers). `version` is suitable for display in an "about" box or a CLI -`--version` output: it can be easily compared against release notes and lists -of bugs fixed in various releases. - -The installer adds the following text to your `__init__.py` to place a basic -version in `YOURPROJECT.__version__`: - - from ._version import get_versions - __version__ = get_versions()['version'] - del get_versions - -## Styles - -The setup.cfg `style=` configuration controls how the VCS information is -rendered into a version string. - -The default style, "pep440", produces a PEP440-compliant string, equal to the -un-prefixed tag name for actual releases, and containing an additional "local -version" section with more detail for in-between builds. For Git, this is -TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags ---dirty --always`. 
For example "0.11+2.g1076c97.dirty" indicates that the -tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and -that this commit is two revisions ("+2") beyond the "0.11" tag. For released -software (exactly equal to a known tag), the identifier will only contain the -stripped tag, e.g. "0.11". - -Other styles are available. See [details.md](details.md) in the Versioneer -source tree for descriptions. - -## Debugging - -Versioneer tries to avoid fatal errors: if something goes wrong, it will tend -to return a version of "0+unknown". To investigate the problem, run `setup.py -version`, which will run the version-lookup code in a verbose mode, and will -display the full contents of `get_versions()` (including the `error` string, -which may help identify what went wrong). - -## Known Limitations - -Some situations are known to cause problems for Versioneer. This details the -most significant ones. More can be found on Github -[issues page](https://github.com/warner/python-versioneer/issues). - -### Subprojects - -Versioneer has limited support for source trees in which `setup.py` is not in -the root directory (e.g. `setup.py` and `.git/` are *not* siblings). The are -two common reasons why `setup.py` might not be in the root: - -* Source trees which contain multiple subprojects, such as - [Buildbot](https://github.com/buildbot/buildbot), which contains both - "master" and "slave" subprojects, each with their own `setup.py`, - `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI - distributions (and upload multiple independently-installable tarballs). -* Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other languages) in subdirectories. - -Versioneer will look for `.git` in parent directories, and most operations -should get the right version string. 
However `pip` and `setuptools` have bugs -and implementation details which frequently cause `pip install .` from a -subproject directory to fail to find a correct version string (so it usually -defaults to `0+unknown`). - -`pip install --editable .` should work correctly. `setup.py install` might -work too. - -Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in -some later version. - -[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking -this issue. The discussion in -[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the -issue from the Versioneer side in more detail. -[pip PR#3176](https://github.com/pypa/pip/pull/3176) and -[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve -pip to let Versioneer work correctly. - -Versioneer-0.16 and earlier only looked for a `.git` directory next to the -`setup.cfg`, so subprojects were completely unsupported with those releases. - -### Editable installs with setuptools <= 18.5 - -`setup.py develop` and `pip install --editable .` allow you to install a -project into a virtualenv once, then continue editing the source code (and -test) without re-installing after every change. - -"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a -convenient way to specify executable scripts that should be installed along -with the python package. - -These both work as expected when using modern setuptools. When using -setuptools-18.5 or earlier, however, certain operations will cause -`pkg_resources.DistributionNotFound` errors when running the entrypoint -script, which must be resolved by re-installing the package. This happens -when the install happens with one version, then the egg_info data is -regenerated while a different version is checked out. Many setup.py commands -cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into -a different virtualenv), so this can be surprising. 
- -[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes -this one, but upgrading to a newer version of setuptools should probably -resolve it. - -### Unicode version strings - -While Versioneer works (and is continually tested) with both Python 2 and -Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. -Newer releases probably generate unicode version strings on py2. It's not -clear that this is wrong, but it may be surprising for applications when then -write these strings to a network connection or include them in bytes-oriented -APIs like cryptographic checksums. - -[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates -this question. - - -## Updating Versioneer - -To upgrade your project to a new release of Versioneer, do the following: - -* install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace - `SRC/_version.py` -* commit any changed files - -## Future Directions - -This tool is designed to make it easily extended to other version-control -systems: all VCS-specific components are in separate directories like -src/git/ . The top-level `versioneer.py` script is assembled from these -components by running make-versioneer.py . In the future, make-versioneer.py -will take a VCS name as an argument, and will construct a version of -`versioneer.py` that is specific to the given VCS. It might also take the -configuration arguments that are currently provided manually during -installation by editing setup.py . Alternatively, it might go the other -direction and include code from all supported VCS systems, reducing the -number of intermediate scripts. - - -## License - -To make Versioneer easier to embed, all its code is dedicated to the public -domain. 
The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . - -""" - -from __future__ import print_function -try: - import configparser -except ImportError: - import ConfigParser as configparser -import errno -import json -import os -import re -import subprocess -import sys - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_root(): - """Get the project root directory. - - We require that all commands are run from the project root, i.e. the - directory that contains setup.py, setup.cfg, and versioneer.py . - """ - root = os.path.realpath(os.path.abspath(os.getcwd())) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - # allow 'python path/to/setup.py COMMAND' - root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") - raise VersioneerBadRootError(err) - try: - # Certain runtime workflows (setup.py install/develop in a setuptools - # tree) execute all dependencies in a single python process, so - # "versioneer" may be imported multiple times, and python's shared - # module-import table will cache the first one. So we can't use - # os.path.dirname(__file__), as that will find whichever - # versioneer.py was first imported, even in later projects. 
- me = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(me)[0]) - vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py)) - except NameError: - pass - return root - - -def get_config_from_root(root): - """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise EnvironmentError (if setup.cfg is missing), or - # configparser.NoSectionError (if it lacks a [versioneer] section), or - # configparser.NoOptionError (if it lacks "VCS="). See the docstring at - # the top of versioneer.py for instructions on writing your setup.cfg . - setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.ConfigParser() - - parser.read(setup_cfg) - VCS = parser.get("versioneer", "VCS") # mandatory - - def get(parser, name): - if parser.has_option("versioneer", name): - return parser.get("versioneer", name) - return None - cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = get(parser, "style") or "" - cfg.versionfile_source = get(parser, "versionfile_source") - cfg.versionfile_build = get(parser, "versionfile_build") - cfg.tag_prefix = get(parser, "tag_prefix") - if cfg.tag_prefix in ("''", '""'): - cfg.tag_prefix = "" - cfg.parentdir_prefix = get(parser, "parentdir_prefix") - cfg.verbose = get(parser, "verbose") - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -# these dictionaries contain VCS-specific tools -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, 
hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode - - -LONG_VERSION_PY['git'] = ''' -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). 
- git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" - git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" - git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "%(STYLE)s" - cfg.tag_prefix = "%(TAG_PREFIX)s" - cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" - cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %%s" %% dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %%s" %% (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = 
stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %%s (error)" %% dispcmd) - print("stdout was %%s" %% stdout) - return None, p.returncode - return stdout, p.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %%s but none started with prefix %%s" %% - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. 
- keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. 
The old git %%d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%%s', no digits" %% ",".join(refs - tags)) - if verbose: - print("likely tags: %%s" %% ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %%s" %% r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. 
- """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %%s not under git control" %% root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%%s*" %% tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? 
- pieces["error"] = ("unable to parse git-describe output: '%%s'" - %% describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%%s' doesn't start with prefix '%%s'" - print(fmt %% (full_tag, tag_prefix)) - pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" - %% (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], - cwd=root)[0].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%%d" %% pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%%d" %% pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 
0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%%s'" %% style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. 
- for i in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} -''' - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. 
However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. 
"2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse 
describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def do_vcs_install(manifest_in, versionfile_source, ipy): - """Git-specific installation logic for Versioneer. - - For Git, this means creating/changing .gitattributes to mark _version.py - for export-subst keyword substitution. 
- """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - files = [manifest_in, versionfile_source] - if ipy: - files.append(ipy) - try: - me = __file__ - if me.endswith(".pyc") or me.endswith(".pyo"): - me = os.path.splitext(me)[0] + ".py" - versioneer_file = os.path.relpath(me) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) - present = False - try: - f = open(".gitattributes", "r") - for line in f.readlines(): - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - f.close() - except EnvironmentError: - pass - if not present: - f = open(".gitattributes", "a+") - f.write("%s export-subst\n" % versionfile_source) - f.close() - files.append(".gitattributes") - run_command(GITS, ["add", "--"] + files) - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.18) from -# revision-control system data, or from the parent directory name of an -# unpacked source archive. Distribution tarballs contain a pre-generated copy -# of this file. 
- -import json - -version_json = ''' -%s -''' # END VERSION_JSON - - -def get_versions(): - return json.loads(version_json) -""" - - -def versions_from_file(filename): - """Try to determine the version from _version.py if present.""" - try: - with open(filename) as f: - contents = f.read() - except EnvironmentError: - raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - raise NotThisMethod("no version_json in _version.py") - return json.loads(mo.group(1)) - - -def write_to_version_file(filename, versions): - """Write the given version number to the given _version.py file.""" - os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) - with open(filename, "w") as f: - f.write(SHORT_VERSION_PY % contents) - - print("set %s to '%s'" % (filename, versions["version"])) - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 
0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -class VersioneerBadRootError(Exception): - """The project root directory is unknown or missing key files.""" - - -def get_versions(verbose=False): - """Get the project version from whatever source is available. - - Returns dict with two keys: 'version' and 'full'. 
- """ - if "versioneer" in sys.modules: - # see the discussion in cmdclass.py:get_cmdclass() - del sys.modules["versioneer"] - - root = get_root() - cfg = get_config_from_root(root) - - assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" - handlers = HANDLERS.get(cfg.VCS) - assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" - assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" - - versionfile_abs = os.path.join(root, cfg.versionfile_source) - - # extract version from first of: _version.py, VCS command (e.g. 'git - # describe'), parentdir. This is meant to work for developers using a - # source checkout, for users of a tarball created by 'setup.py sdist', - # and for users of a tarball/zipball created by 'git archive' or github's - # download-from-tag feature or the equivalent in other VCSes. - - get_keywords_f = handlers.get("get_keywords") - from_keywords_f = handlers.get("keywords") - if get_keywords_f and from_keywords_f: - try: - keywords = get_keywords_f(versionfile_abs) - ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) - if verbose: - print("got version from expanded keyword %s" % ver) - return ver - except NotThisMethod: - pass - - try: - ver = versions_from_file(versionfile_abs) - if verbose: - print("got version from file %s %s" % (versionfile_abs, ver)) - return ver - except NotThisMethod: - pass - - from_vcs_f = handlers.get("pieces_from_vcs") - if from_vcs_f: - try: - pieces = from_vcs_f(cfg.tag_prefix, root, verbose) - ver = render(pieces, cfg.style) - if verbose: - print("got version from VCS %s" % ver) - return ver - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - if verbose: - print("got version from parentdir %s" % ver) - return ver - except NotThisMethod: - pass - - if verbose: - 
print("unable to compute version") - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} - - -def get_version(): - """Get the short version string for this project.""" - return get_versions()["version"] - - -def get_cmdclass(): - """Get the custom setuptools/distutils subclasses used by Versioneer.""" - if "versioneer" in sys.modules: - del sys.modules["versioneer"] - # this fixes the "python setup.py develop" case (also 'install' and - # 'easy_install .'), in which subdependencies of the main project are - # built (using setup.py bdist_egg) in the same python process. Assume - # a main project A and a dependency B, which use different versions - # of Versioneer. A's setup.py imports A's Versioneer, leaving it in - # sys.modules by the time B's setup.py is executed, causing B to run - # with the wrong versioneer. Setuptools wraps the sub-dep builds in a - # sandbox that restores sys.modules to it's pre-build state, so the - # parent is protected against the child's "import versioneer". By - # removing ourselves from sys.modules here, before the child build - # happens, we protect the child from the parent's versioneer too. 
- # Also see https://github.com/warner/python-versioneer/issues/52 - - cmds = {} - - # we add "version" to both distutils and setuptools - from distutils.core import Command - - class cmd_version(Command): - description = "report generated version string" - user_options = [] - boolean_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - vers = get_versions(verbose=True) - print("Version: %s" % vers["version"]) - print(" full-revisionid: %s" % vers.get("full-revisionid")) - print(" dirty: %s" % vers.get("dirty")) - print(" date: %s" % vers.get("date")) - if vers["error"]: - print(" error: %s" % vers["error"]) - cmds["version"] = cmd_version - - # we override "build_py" in both distutils and setuptools - # - # most invocation pathways end up running build_py: - # distutils/build -> build_py - # distutils/install -> distutils/build ->.. - # setuptools/bdist_wheel -> distutils/install ->.. - # setuptools/bdist_egg -> distutils/install_lib -> build_py - # setuptools/install -> bdist_egg ->.. - # setuptools/develop -> ? - # pip install: - # copies source tree to a tempdir before running egg_info/etc - # if .git isn't copied too, 'git describe' will fail - # then does setup.py bdist_wheel, or sometimes setup.py install - # setup.py egg_info -> ? 
- - # we override different "build_py" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.build_py import build_py as _build_py - else: - from distutils.command.build_py import build_py as _build_py - - class cmd_build_py(_build_py): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_py.run(self) - # now locate _version.py in the new build/ directory and replace - # it with an updated value - if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - cmds["build_py"] = cmd_build_py - - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe - # nczeczulin reports that py2exe won't like the pep440-style string - # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. - # setup(console=[{ - # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION - # "product_version": versioneer.get_version(), - # ... - - class cmd_build_exe(_build_exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _build_exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["build_exe"] = cmd_build_exe - del cmds["build_py"] - - if 'py2exe' in sys.modules: # py2exe enabled? 
- try: - from py2exe.distutils_buildexe import py2exe as _py2exe # py3 - except ImportError: - from py2exe.build_exe import py2exe as _py2exe # py2 - - class cmd_py2exe(_py2exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _py2exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["py2exe"] = cmd_py2exe - - # we override different "sdist" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.sdist import sdist as _sdist - else: - from distutils.command.sdist import sdist as _sdist - - class cmd_sdist(_sdist): - def run(self): - versions = get_versions() - self._versioneer_generated_versions = versions - # unless we update this, the command will keep using the old - # version - self.distribution.metadata.version = versions["version"] - return _sdist.run(self) - - def make_release_tree(self, base_dir, files): - root = get_root() - cfg = get_config_from_root(root) - _sdist.make_release_tree(self, base_dir, files) - # now locate _version.py in the new base_dir directory - # (remembering that it may be a hardlink) and replace it with an - # updated value - target_versionfile = os.path.join(base_dir, cfg.versionfile_source) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) - cmds["sdist"] = cmd_sdist - - return cmds - - -CONFIG_ERROR = """ -setup.cfg is missing the necessary Versioneer configuration. 
You need -a section like: - - [versioneer] - VCS = git - style = pep440 - versionfile_source = src/myproject/_version.py - versionfile_build = myproject/_version.py - tag_prefix = - parentdir_prefix = myproject- - -You will also need to edit your setup.py to use the results: - - import versioneer - setup(version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), ...) - -Please read the docstring in ./versioneer.py for configuration instructions, -edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. -""" - -SAMPLE_CONFIG = """ -# See the docstring in versioneer.py for instructions. Note that you must -# re-run 'versioneer.py setup' after changing this section, and commit the -# resulting files. - -[versioneer] -#VCS = git -#style = pep440 -#versionfile_source = -#versionfile_build = -#tag_prefix = -#parentdir_prefix = - -""" - -INIT_PY_SNIPPET = """ -from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions -""" - - -def do_setup(): - """Main VCS-independent setup function for installing Versioneer.""" - root = get_root() - try: - cfg = get_config_from_root(root) - except (EnvironmentError, configparser.NoSectionError, - configparser.NoOptionError) as e: - if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) - with open(os.path.join(root, "setup.cfg"), "a") as f: - f.write(SAMPLE_CONFIG) - print(CONFIG_ERROR, file=sys.stderr) - return 1 - - print(" creating %s" % cfg.versionfile_source) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") - if os.path.exists(ipy): - try: - with open(ipy, "r") as f: - old = f.read() - 
except EnvironmentError: - old = "" - if INIT_PY_SNIPPET not in old: - print(" appending to %s" % ipy) - with open(ipy, "a") as f: - f.write(INIT_PY_SNIPPET) - else: - print(" %s unmodified" % ipy) - else: - print(" %s doesn't exist, ok" % ipy) - ipy = None - - # Make sure both the top-level "versioneer.py" and versionfile_source - # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so - # they'll be copied into source distributions. Pip won't be able to - # install the package without this. - manifest_in = os.path.join(root, "MANIFEST.in") - simple_includes = set() - try: - with open(manifest_in, "r") as f: - for line in f: - if line.startswith("include "): - for include in line.split()[1:]: - simple_includes.add(include) - except EnvironmentError: - pass - # That doesn't cover everything MANIFEST.in can do - # (http://docs.python.org/2/distutils/sourcedist.html#commands), so - # it might give some false negatives. Appending redundant 'include' - # lines is safe, though. - if "versioneer.py" not in simple_includes: - print(" appending 'versioneer.py' to MANIFEST.in") - with open(manifest_in, "a") as f: - f.write("include versioneer.py\n") - else: - print(" 'versioneer.py' already in MANIFEST.in") - if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % - cfg.versionfile_source) - with open(manifest_in, "a") as f: - f.write("include %s\n" % cfg.versionfile_source) - else: - print(" versionfile_source already in MANIFEST.in") - - # Make VCS-specific changes. For git, this means creating/changing - # .gitattributes to mark _version.py for export-subst keyword - # substitution. 
- do_vcs_install(manifest_in, cfg.versionfile_source, ipy) - return 0 - - -def scan_setup_py(): - """Validate the contents of setup.py against Versioneer's expectations.""" - found = set() - setters = False - errors = 0 - with open("setup.py", "r") as f: - for line in f.readlines(): - if "import versioneer" in line: - found.add("import") - if "versioneer.get_cmdclass()" in line: - found.add("cmdclass") - if "versioneer.get_version()" in line: - found.add("get_version") - if "versioneer.VCS" in line: - setters = True - if "versioneer.versionfile_source" in line: - setters = True - if len(found) != 3: - print("") - print("Your setup.py appears to be missing some important items") - print("(but I might be wrong). Please make sure it has something") - print("roughly like the following:") - print("") - print(" import versioneer") - print(" setup( version=versioneer.get_version(),") - print(" cmdclass=versioneer.get_cmdclass(), ...)") - print("") - errors += 1 - if setters: - print("You should remove lines like 'versioneer.VCS = ' and") - print("'versioneer.versionfile_source = ' . This configuration") - print("now lives in setup.cfg, and should be removed from setup.py") - print("") - errors += 1 - return errors - - -if __name__ == "__main__": - cmd = sys.argv[1] - if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1)