diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index 667d139a..30ce51aa 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -57,7 +57,7 @@ jobs: cache: pip cache-dependency-path: "**/pyproject.toml" - - uses: ts-graphviz/setup-graphviz@v1 + - uses: ts-graphviz/setup-graphviz@v2 with: macos-skip-brew-update: true @@ -85,12 +85,13 @@ jobs: steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 + with: { python-version: 3.x } + - run: python -m pip install pre-commit && python -m pip freeze --local + - uses: actions/cache@v4 with: - python-version: "3.x" - + path: ~/.cache/pre-commit + key: pre-commit|${{ env.pythonLocation }}|${{ hashFiles('.pre-commit-config.yaml') }} - name: Force recreation of pre-commit virtual environment for mypy if: github.event_name == 'schedule' # Comment this line to run on a PR - run: gh cache list -L 999 | cut -f2 | grep pre-commit | xargs -I{} gh cache delete "{}" || true - env: { GH_TOKEN: "${{ github.token }}" } - - - uses: pre-commit/action@v3.0.0 + run: pre-commit clean + - run: pre-commit run --all-files --color=always --show-diff-on-failure diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7aa04387..062450d3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,7 @@ repos: - xarray args: [] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.14 + rev: v0.2.1 hooks: - id: ruff - id: ruff-format diff --git a/doc/api.rst b/doc/api.rst index e2216653..b59904ed 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -11,6 +11,7 @@ Top-level classes and functions configure Computer Key + KeySeq Quantity .. autofunction:: configure @@ -259,15 +260,19 @@ Top-level classes and functions Keys can also be manipulated using some of the Python arithmetic operators: - - :py:`+`: same as :meth:`.add_tag`: + - :py:`+`: and :py:`-`: manipulate :attr:`.tag`, same as :meth:`.add_tag` and :meth:`.remove_tag` respectively: - >>> k1 = Key("foo", "abc") + >>> k1 = Key("foo", "abc", "bar+baz+qux") >>> k1 - - >>> k1 + "tag" - + + >>> k2 + "newtag" + + >>> k1 - "baz" + + >>> k1 - ("bar", "baz") + - - :py:`*` with a single string, an iterable of strings, or another Key: similar to :meth:`.append` and :meth:`.product`: + - :py:`*` and :py:`/`: manipulate :attr:`dims`, similar to :meth:`.append`/:attr:`.product` and :attr:`.drop`, respectively: >>> k1 * "d" @@ -276,8 +281,6 @@ Top-level classes and functions >>> k1 * Key("bar", "ghi") - - :py:`/` with a single string or iterable of strings: similar to :meth:`drop`: - >>> k1 / "a" >>> k1 / ("a", "c") @@ -285,6 +288,72 @@ Top-level classes and functions >>> k1 / Key("baz", "cde") +.. autoclass:: genno.KeySeq + :members: + + When preparing chains or complicated graphs of computations, it can be useful to use a sequence or set of similar keys to refer to the intermediate steps. + The :class:`.KeySeq` class is provided for this purpose. + It supports several ways to create related keys starting from a *base key*: + + >>> ks = KeySeq("foo:x-y-z:bar") + + One may: + + - Use item access syntax: + + >>> ks["a"] + + >>> ks["b"] + + + - Use the Python built-in :func:`.next`. + This always returns the next key in a sequence of integers, starting with :py:`0` and continuing from the *highest previously created Key*: + + >>> next(ks) + + + # Skip some values + >>> ks[5] + + + # next() continues from the highest + >>> next(ks) + + + - Treat the KeySeq as callable, optionally with any value that has a :class:`.str` representation: + + >>> ks("c") + + + # Same as next() + >>> ks() + + + - Access the most recently generated item: + + >>> ks.prev + + + - Access the base Key or its properties: + + >>> ks.base + + >>> ks.name + "foo" + + - Access a :class:`dict` of all previously-created keys. + Because :class:`dict` is order-preserving, the order of keys and values reflects the order in which they were created: + + >>> tuple(ks.keys) + ("a", "b", 0, 5, 6, "a", 7) + + The same Python arithmetic operators usable with Key are usable with KeySeq; they return a new KeySeq with a different :attr:`~.KeySeq.base`: + + >>> ks * "w" + + >>> ks / ("x", "z") + + .. autoclass:: genno.Quantity :members: :inherited-members: pipe, shape, size diff --git a/doc/whatsnew.rst b/doc/whatsnew.rst index 15ef4aaa..a69a37d2 100644 --- a/doc/whatsnew.rst +++ b/doc/whatsnew.rst @@ -1,8 +1,11 @@ What's new ********** -.. Next release -.. ============ +Next release +============ + +- Add :class:`.KeySeq` class for creating sequences or sets of similar :class:`Keys <.Key>` (:pull:`126`). +- Add :meth:`.Key.remove_tag` method and support for :py:`k - "foo"` syntax for removing tags from :class:`.Key` (:pull:`126`). v1.23.1 (2024-02-01) ==================== diff --git a/genno/__init__.py b/genno/__init__.py index a2d28969..748078b5 100644 --- a/genno/__init__.py +++ b/genno/__init__.py @@ -4,7 +4,7 @@ from .config import configure from .core.computer import Computer from .core.exceptions import ComputationError, KeyExistsError, MissingKeyError -from .core.key import Key +from .core.key import Key, KeySeq from .core.operator import Operator from .core.quantity import Quantity @@ -12,6 +12,7 @@ "ComputationError", "Computer", "Key", + "KeySeq", "KeyExistsError", "MissingKeyError", "Operator", diff --git a/genno/core/key.py b/genno/core/key.py index abfb1678..d928d5fb 100644 --- a/genno/core/key.py +++ b/genno/core/key.py @@ -2,7 +2,20 @@ import re from functools import partial, singledispatch from itertools import chain, compress -from typing import Callable, Generator, Iterable, Iterator, Optional, Tuple, Union +from types import MappingProxyType +from typing import ( + Callable, + Dict, + Generator, + Hashable, + Iterable, + Iterator, + Optional, + Sequence, + SupportsInt, + Tuple, + Union, +) from warnings import warn from genno.core.quantity import Quantity @@ -184,33 +197,37 @@ def product(cls, new_name: str, *keys, tag: Optional[str] = None) -> "Key": # Return new key. Use dict to keep only unique *dims*, in same order return cls(new_name, dict.fromkeys(dims).keys()).add_tag(tag) - def __add__(self, other) -> "Key": - if isinstance(other, str): - return self.add_tag(other) - else: + def __add__(self, other: str) -> "Key": + if not isinstance(other, str): raise TypeError(type(other)) + return self.add_tag(other) + + def __sub__(self, other: Union[str, Iterable[str]]) -> "Key": + return self.remove_tag(*((other,) if isinstance(other, str) else other)) - def __mul__(self, other) -> "Key": + def __mul__(self, other: Union[str, "Key", Sequence[str]]) -> "Key": if isinstance(other, str): - return self.append(other) + other_dims: Sequence[str] = (other,) + elif isinstance(other, Key): + other_dims = other.dims + elif isinstance(other, Sequence): + other_dims = other else: - # Key or iterable of dims - other_dims = getattr(other, "dims", other) - try: - return self.append(*other_dims) - except Exception: - raise TypeError(type(other)) - - def __truediv__(self, other) -> "Key": + raise TypeError(type(other)) + + return self.append(*other_dims) + + def __truediv__(self, other: Union[str, "Key", Sequence[str]]) -> "Key": if isinstance(other, str): - return self.drop(other) + other_dims: Sequence[str] = (other,) + elif isinstance(other, Key): + other_dims = other.dims + elif isinstance(other, Sequence): + other_dims = other else: - # Key or iterable of dims - other_dims = getattr(other, "dims", other) - try: - return self.drop(*other_dims) - except Exception: - raise TypeError(type(other)) + raise TypeError(type(other)) + + return self.drop(*other_dims) def __repr__(self) -> str: """Representation of the Key, e.g. '.""" @@ -295,7 +312,7 @@ def append(self, *dims: str) -> "Key": """Return a new Key with additional dimensions `dims`.""" return Key(self._name, list(self._dims) + list(dims), self._tag, _fast=True) - def add_tag(self, tag) -> "Key": + def add_tag(self, tag: Optional[str]) -> "Key": """Return a new Key with `tag` appended.""" return Key( self._name, self._dims, "+".join(filter(None, [self._tag, tag])), _fast=True @@ -312,6 +329,20 @@ def iter_sums(self) -> Generator[Tuple["Key", Callable, "Key"], None, None]: self, ) + def remove_tag(self, *tags: str) -> "Key": + """Return a key with any of `tags` dropped. + + Raises + ------ + ValueError + If none of `tags` are in :attr:`.tags`. + """ + new_tags = tuple(filter(lambda t: t not in tags, (self.tag or "").split("+"))) + new_tag = "+".join(new_tags) if new_tags else None + if new_tag == self.tag: + raise ValueError(f"No existing tags {tags!r} to remove") + return Key(self._name, self._dims, new_tag, _fast=True) + @_name_dims_tag.register def _(value: Key): @@ -319,6 +350,79 @@ def _(value: Key): return value._name, value._dims, value._tag +class KeySeq: + """Utility class for generating similar :class:`Keys <.Key>`.""" + + #: Base :class:`.Key` of the sequence. + base: Key + + # Keys that have been created. + _keys: Dict[Hashable, Key] + + def __init__(self, *args, **kwargs): + self.base = Key(*args, **kwargs) + self._keys = {} + + def _next_int_tag(self) -> int: + return max([-1] + [t for t in self._keys if isinstance(t, int)]) + 1 + + def __next__(self) -> Key: + return self[self._next_int_tag()] + + def __call__(self, value: Optional[Hashable] = None) -> Key: + return next(self) if value is None else self[value] + + def __getitem__(self, value: Hashable) -> Key: + tag = int(value) if isinstance(value, SupportsInt) else str(value) + result = self._keys[tag] = self.base + str(tag) + return result + + def __repr__(self) -> str: + return f"" + + @property + def keys(self) -> MappingProxyType: + """Read-only view of previously-created :class:`Keys <.Key>`. + + In the form of a :class:`dict` mapping tags (:class:`int` or :class:`str`) to + :class:`.Key` values. + """ + return MappingProxyType(self._keys) + + @property + def prev(self) -> Key: + """The most recently created :class:`.Key`.""" + return next(reversed(self._keys.values())) + + # Access to Key properties + @property + def name(self) -> str: + """Name of the :attr:`.base` Key.""" + return self.base.name + + @property + def dims(self) -> Tuple[str, ...]: + """Dimensions of the :attr:`.base` Key.""" + return self.base.dims + + @property + def tag(self) -> Optional[str]: + """Tag of the :attr:`.base` Key.""" + return self.base.tag + + def __add__(self, other: str) -> "KeySeq": + return KeySeq(self.base + other) + + def __mul__(self, other) -> "KeySeq": + return KeySeq(self.base * other) + + def __sub__(self, other: Union[str, Iterable[str]]) -> "KeySeq": + return KeySeq(self.base - other) + + def __truediv__(self, other) -> "KeySeq": + return KeySeq(self.base / other) + + #: Type shorthand for :class:`Key` or any other value that can be used as a key. KeyLike = Union[Key, str] diff --git a/genno/testing/__init__.py b/genno/testing/__init__.py index 766cb31d..7fced4cf 100644 --- a/genno/testing/__init__.py +++ b/genno/testing/__init__.py @@ -435,24 +435,28 @@ def raises_or_warns(value, *args, **kwargs) -> ContextManager: Examples -------- - @pytest.mark.parametrize( - "input, output", (("FOO", 1), ("BAR", pytest.raises(ValueError))) - ) - def test_myfunc0(input, expected): - with raises_or_warns(expected, DeprecationWarning, match="FOO"): - assert expected == myfunc(input) + .. code-block:: python + + @pytest.mark.parametrize( + "input, output", (("FOO", 1), ("BAR", pytest.raises(ValueError))) + ) + def test_myfunc0(input, expected): + with raises_or_warns(expected, DeprecationWarning, match="FOO"): + assert expected == myfunc(input) In this example: - :py:`myfunc("FOO")` is expected to emit :class:`DeprecationWarning` and return 1. - :py:`myfunc("BAR")` is expected to raise :class:`ValueError` and issue no warning. - @pytest.mark.parametrize( - "input, output", (("FOO", 1), ("BAR", pytest.raises(ValueError))) - ) - def test_myfunc1(input, expected): - with raises_or_warns(expected, None): - assert expected == myfunc(input) + .. code-block:: python + + @pytest.mark.parametrize( + "input, output", (("FOO", 1), ("BAR", pytest.raises(ValueError))) + ) + def test_myfunc1(input, expected): + with raises_or_warns(expected, None): + assert expected == myfunc(input) In this example, no warnings are expected from :py:`myfunc("FOO")`. """ diff --git a/genno/tests/core/test_key.py b/genno/tests/core/test_key.py index 2df7f42c..2d4cf788 100644 --- a/genno/tests/core/test_key.py +++ b/genno/tests/core/test_key.py @@ -1,6 +1,6 @@ import pytest -from genno import Key +from genno import Key, KeySeq from genno.core.key import iter_keys, single_key from genno.testing import raises_or_warns @@ -135,6 +135,72 @@ def test_operations(self): key / 3.3 +class TestKeySeq: + @pytest.fixture + def ks(self) -> KeySeq: + return KeySeq("foo:x-y-z:bar") + + def test_call(self, ks) -> None: + assert "foo:x-y-z:bar+0" == ks() + assert "foo:x-y-z:bar+1" == ks() + assert "foo:x-y-z:bar+2" == ks() + assert "foo:x-y-z:bar+2" == ks.prev + + # Continues from interruption + ks[5] + assert "foo:x-y-z:bar+6" == next(ks) + + def test_getitem(self, ks) -> None: + assert "foo:x-y-z:bar+baz" == ks["baz"] + assert "foo:x-y-z:bar+qux" == ks["qux"] + assert "foo:x-y-z:bar+qux" == ks.prev + assert "foo:x-y-z:bar+0" == next(ks) + assert "foo:x-y-z:bar+0" == ks.prev + + def test_keys(sefl, ks) -> None: + ks["foo"] + ks[5] + next(ks) + ks["baz"] + ks[0] + + # .keys preserves order of creation + assert ("foo", 5, 6, "baz", 0) == tuple(ks.keys) + + def test_next(self, ks) -> None: + assert "foo:x-y-z:bar+0" == next(ks) + assert "foo:x-y-z:bar+1" == next(ks) + assert "foo:x-y-z:bar+2" == next(ks) + assert "foo:x-y-z:bar+2" == ks.prev + + # Continues from interruption + ks[5] + assert "foo:x-y-z:bar+6" == next(ks) + + def test_repr(self, ks) -> None: + assert "" == repr(ks) + + def test_key_attrs(self, ks) -> None: + assert "foo" == ks.name + assert ("x", "y", "z") == ks.dims + assert "bar" == ks.tag + + def test_key_ops(self, ks) -> None: + # __add__ + assert "foo:x-y-z:bar+baz" == (ks + "baz").base + + # __mul__ + assert "foo:w-x-y-z:bar" == (ks * "w").base + + # __sub__ + assert "foo:x-y-z" == (ks - "bar").base + with pytest.raises(ValueError): + ks - "qux" + + # __truediv__ + assert "foo:x-z:bar" == (ks / "y").base + + def test_sorted(): k1 = Key("foo", "abc") k2 = Key("foo", "cba")