Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KeySeq class #126

Merged
merged 8 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ jobs:
cache: pip
cache-dependency-path: "**/pyproject.toml"

- uses: ts-graphviz/setup-graphviz@v1
- uses: ts-graphviz/setup-graphviz@v2
with:
macos-skip-brew-update: true

Expand Down Expand Up @@ -85,12 +85,13 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with: { python-version: 3.x }
- run: python -m pip install pre-commit && python -m pip freeze --local
- uses: actions/cache@v4
with:
python-version: "3.x"

path: ~/.cache/pre-commit
key: pre-commit|${{ env.pythonLocation }}|${{ hashFiles('.pre-commit-config.yaml') }}
- name: Force recreation of pre-commit virtual environment for mypy
if: github.event_name == 'schedule' # Comment this line to run on a PR
run: gh cache list -L 999 | cut -f2 | grep pre-commit | xargs -I{} gh cache delete "{}" || true
env: { GH_TOKEN: "${{ github.token }}" }

- uses: pre-commit/[email protected]
run: pre-commit clean
- run: pre-commit run --all-files --color=always --show-diff-on-failure
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ repos:
- xarray
args: []
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.14
rev: v0.2.1
hooks:
- id: ruff
- id: ruff-format
Expand Down
85 changes: 77 additions & 8 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Top-level classes and functions
configure
Computer
Key
KeySeq
Quantity

.. autofunction:: configure
Expand Down Expand Up @@ -259,15 +260,19 @@ Top-level classes and functions

Keys can also be manipulated using some of the Python arithmetic operators:

- :py:`+`: same as :meth:`.add_tag`:
- :py:`+`: and :py:`-`: manipulate :attr:`.tag`, same as :meth:`.add_tag` and :meth:`.remove_tag` respectively:

>>> k1 = Key("foo", "abc")
>>> k1 = Key("foo", "abc", "bar+baz+qux")
>>> k1
<foo:a-b-c>
>>> k1 + "tag"
<foo:a-b-c:tag>
<foo:a-b-c:bar+baz+qux>
>>> k2 + "newtag"
<foo:a-b-c:bar+baz+qux+newtag>
>>> k1 - "baz"
<foo:a-b-c:bar+qux>
>>> k1 - ("bar", "baz")
<foo:a-b-c:qux>

- :py:`*` with a single string, an iterable of strings, or another Key: similar to :meth:`.append` and :meth:`.product`:
- :py:`*` and :py:`/`: manipulate :attr:`dims`, similar to :meth:`.append`/:attr:`.product` and :attr:`.drop`, respectively:

>>> k1 * "d"
<foo:a-b-c-d>
Expand All @@ -276,15 +281,79 @@ Top-level classes and functions
>>> k1 * Key("bar", "ghi")
<foo:a-b-c-g-h-i>

- :py:`/` with a single string or iterable of strings: similar to :meth:`drop`:

>>> k1 / "a"
<foo:b-c>
>>> k1 / ("a", "c")
<foo:b>
>>> k1 / Key("baz", "cde")
<foo:a-b>

.. autoclass:: genno.KeySeq
:members:

When preparing chains or complicated graphs of computations, it can be useful to use a sequence or set of similar keys to refer to the intermediate steps.
The :class:`.KeySeq` class is provided for this purpose.
It supports several ways to create related keys starting from a *base key*:

>>> ks = KeySeq("foo:x-y-z:bar")

One may:

- Use item access syntax:

>>> ks["a"]
<foo:x-y-z:bar+a>
>>> ks["b"]
<foo:x-y-z:bar+b>

- Use the Python built-in :func:`.next`.
This always returns the next key in a sequence of integers, starting with :py:`0` and continuing from the *highest previously created Key*:

>>> next(ks)
<foo:x-y-z:bar+0>

# Skip some values
>>> ks[5]
<foo:x-y-z:bar+5>

# next() continues from the highest
>>> next(ks)
<foo:x-y-z:bar+6>

- Treat the KeySeq as callable, optionally with any value that has a :class:`.str` representation:

>>> ks("c")
<foo:x-y-z:bar+c>

# Same as next()
>>> ks()
<foo:x-y-z:bar+7>

- Access the most recently generated item:

>>> ks.prev
<foo:x-y-z:bar+7>

- Access the base Key or its properties:

>>> ks.base
<foo:x-y-z:bar>
>>> ks.name
"foo"

- Access a :class:`dict` of all previously-created keys.
Because :class:`dict` is order-preserving, the order of keys and values reflects the order in which they were created:

>>> tuple(ks.keys)
("a", "b", 0, 5, 6, "a", 7)

The same Python arithmetic operators usable with Key are usable with KeySeq; they return a new KeySeq with a different :attr:`~.KeySeq.base`:

>>> ks * "w"
<KeySeq from 'foo:x-y-z-w:bar'>
>>> ks / ("x", "z")
<KeySeq from 'foo:z:bar'>

.. autoclass:: genno.Quantity
:members:
:inherited-members: pipe, shape, size
Expand Down
7 changes: 5 additions & 2 deletions doc/whatsnew.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
What's new
**********

.. Next release
.. ============
Next release
============

- Add :class:`.KeySeq` class for creating sequences or sets of similar :class:`Keys <.Key>` (:pull:`126`).
- Add :meth:`.Key.remove_tag` method and support for :py:`k - "foo"` syntax for removing tags from :class:`.Key` (:pull:`126`).

v1.23.1 (2024-02-01)
====================
Expand Down
3 changes: 2 additions & 1 deletion genno/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from .config import configure
from .core.computer import Computer
from .core.exceptions import ComputationError, KeyExistsError, MissingKeyError
from .core.key import Key
from .core.key import Key, KeySeq
from .core.operator import Operator
from .core.quantity import Quantity

__all__ = [
"ComputationError",
"Computer",
"Key",
"KeySeq",
"KeyExistsError",
"MissingKeyError",
"Operator",
Expand Down
150 changes: 127 additions & 23 deletions genno/core/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,20 @@
import re
from functools import partial, singledispatch
from itertools import chain, compress
from typing import Callable, Generator, Iterable, Iterator, Optional, Tuple, Union
from types import MappingProxyType
from typing import (
Callable,
Dict,
Generator,
Hashable,
Iterable,
Iterator,
Optional,
Sequence,
SupportsInt,
Tuple,
Union,
)
from warnings import warn

from genno.core.quantity import Quantity
Expand Down Expand Up @@ -184,33 +197,37 @@ def product(cls, new_name: str, *keys, tag: Optional[str] = None) -> "Key":
# Return new key. Use dict to keep only unique *dims*, in same order
return cls(new_name, dict.fromkeys(dims).keys()).add_tag(tag)

def __add__(self, other) -> "Key":
if isinstance(other, str):
return self.add_tag(other)
else:
def __add__(self, other: str) -> "Key":
if not isinstance(other, str):
raise TypeError(type(other))
return self.add_tag(other)

def __sub__(self, other: Union[str, Iterable[str]]) -> "Key":
return self.remove_tag(*((other,) if isinstance(other, str) else other))

def __mul__(self, other) -> "Key":
def __mul__(self, other: Union[str, "Key", Sequence[str]]) -> "Key":
if isinstance(other, str):
return self.append(other)
other_dims: Sequence[str] = (other,)
elif isinstance(other, Key):
other_dims = other.dims
elif isinstance(other, Sequence):
other_dims = other
else:
# Key or iterable of dims
other_dims = getattr(other, "dims", other)
try:
return self.append(*other_dims)
except Exception:
raise TypeError(type(other))

def __truediv__(self, other) -> "Key":
raise TypeError(type(other))

return self.append(*other_dims)

def __truediv__(self, other: Union[str, "Key", Sequence[str]]) -> "Key":
if isinstance(other, str):
return self.drop(other)
other_dims: Sequence[str] = (other,)
elif isinstance(other, Key):
other_dims = other.dims
elif isinstance(other, Sequence):
other_dims = other
else:
# Key or iterable of dims
other_dims = getattr(other, "dims", other)
try:
return self.drop(*other_dims)
except Exception:
raise TypeError(type(other))
raise TypeError(type(other))

return self.drop(*other_dims)

def __repr__(self) -> str:
"""Representation of the Key, e.g. '<name:dim1-dim2-dim3:tag>."""
Expand Down Expand Up @@ -295,7 +312,7 @@ def append(self, *dims: str) -> "Key":
"""Return a new Key with additional dimensions `dims`."""
return Key(self._name, list(self._dims) + list(dims), self._tag, _fast=True)

def add_tag(self, tag) -> "Key":
def add_tag(self, tag: Optional[str]) -> "Key":
"""Return a new Key with `tag` appended."""
return Key(
self._name, self._dims, "+".join(filter(None, [self._tag, tag])), _fast=True
Expand All @@ -312,13 +329,100 @@ def iter_sums(self) -> Generator[Tuple["Key", Callable, "Key"], None, None]:
self,
)

def remove_tag(self, *tags: str) -> "Key":
"""Return a key with any of `tags` dropped.

Raises
------
ValueError
If none of `tags` are in :attr:`.tags`.
"""
new_tags = tuple(filter(lambda t: t not in tags, (self.tag or "").split("+")))
new_tag = "+".join(new_tags) if new_tags else None
if new_tag == self.tag:
raise ValueError(f"No existing tags {tags!r} to remove")
return Key(self._name, self._dims, new_tag, _fast=True)


@_name_dims_tag.register
def _(value: Key):
"""Return the (name, dims, tag) of an existing Key."""
return value._name, value._dims, value._tag


class KeySeq:
"""Utility class for generating similar :class:`Keys <.Key>`."""

#: Base :class:`.Key` of the sequence.
base: Key

# Keys that have been created.
_keys: Dict[Hashable, Key]

def __init__(self, *args, **kwargs):
self.base = Key(*args, **kwargs)
self._keys = {}

def _next_int_tag(self) -> int:
return max([-1] + [t for t in self._keys if isinstance(t, int)]) + 1

def __next__(self) -> Key:
return self[self._next_int_tag()]

def __call__(self, value: Optional[Hashable] = None) -> Key:
return next(self) if value is None else self[value]

def __getitem__(self, value: Hashable) -> Key:
tag = int(value) if isinstance(value, SupportsInt) else str(value)
result = self._keys[tag] = self.base + str(tag)
return result

def __repr__(self) -> str:
return f"<KeySeq from '{self.base!s}'>"

@property
def keys(self) -> MappingProxyType:
"""Read-only view of previously-created :class:`Keys <.Key>`.

In the form of a :class:`dict` mapping tags (:class:`int` or :class:`str`) to
:class:`.Key` values.
"""
return MappingProxyType(self._keys)

@property
def prev(self) -> Key:
"""The most recently created :class:`.Key`."""
return next(reversed(self._keys.values()))

# Access to Key properties
@property
def name(self) -> str:
"""Name of the :attr:`.base` Key."""
return self.base.name

@property
def dims(self) -> Tuple[str, ...]:
"""Dimensions of the :attr:`.base` Key."""
return self.base.dims

@property
def tag(self) -> Optional[str]:
"""Tag of the :attr:`.base` Key."""
return self.base.tag

def __add__(self, other: str) -> "KeySeq":
return KeySeq(self.base + other)

def __mul__(self, other) -> "KeySeq":
return KeySeq(self.base * other)

def __sub__(self, other: Union[str, Iterable[str]]) -> "KeySeq":
return KeySeq(self.base - other)

def __truediv__(self, other) -> "KeySeq":
return KeySeq(self.base / other)


#: Type shorthand for :class:`Key` or any other value that can be used as a key.
KeyLike = Union[Key, str]

Expand Down
Loading