From 78da5cccfcf188989d2aca9b28aec0e25e81bb96 Mon Sep 17 00:00:00 2001 From: Izaak Date: Fri, 26 Mar 2021 18:59:01 +0100 Subject: [PATCH 01/13] Fix small bug in JSON encoder (#64) * fix bug: don't forget last line, divide evenly, test data same before/after encoding * remove overcorrection --- src/correctionlib/JSONEncoder.py | 11 ++++++++--- tests/test_jsonencoder.py | 31 +++++++++++++++++++------------ 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/correctionlib/JSONEncoder.py b/src/correctionlib/JSONEncoder.py index c9d86740..37a92322 100755 --- a/src/correctionlib/JSONEncoder.py +++ b/src/correctionlib/JSONEncoder.py @@ -93,13 +93,18 @@ def encode(self, obj: Any) -> str: output.append(json.dumps(item)) retval = "[ " + ", ".join(output) + " ]" else: # break long list into multiple lines - nlines = math.ceil(len(obj) / float(self.maxlistlen)) - maxlen = int(len(obj) / nlines) + nlines = math.ceil( + len(obj) / float(self.maxlistlen) + ) # number of lines + maxlen = int( + math.ceil(len(obj) / nlines) + ) # divide evenly over nlines for i in range(0, nlines): line = [] for item in obj[i * maxlen : (i + 1) * maxlen]: line.append(json.dumps(item)) - output.append(", ".join(line)) + if line: + output.append(", ".join(line)) if not retval: lines = (",\n" + indent_str).join(output) # lines between brackets if ( diff --git a/tests/test_jsonencoder.py b/tests/test_jsonencoder.py index aa41612e..8fd1087a 100755 --- a/tests/test_jsonencoder.py +++ b/tests/test_jsonencoder.py @@ -1,4 +1,4 @@ -from correctionlib.JSONEncoder import dumps +from correctionlib.JSONEncoder import dumps, json def test_jsonencode(): @@ -125,6 +125,8 @@ def test_jsonencode(): breakbrackets=False, ) + retrieved = json.loads(formatted) + expected = """\ { "layer1": { @@ -167,12 +169,12 @@ def test_jsonencode(): [ "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z" ], - [ "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", - "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "1", "2" + [ "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", + "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "1", "2", "3" ], - [ "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", - "r", "s", "t", "u", "v", "w", "x", "y", "z", "a", "b", "c", "d", "e", "f", "g", "h", - "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y" + [ "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", + "s", "t", "u", "v", "w", "x", "y", "z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", + "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z" ], [ "this is short", "very short" ], [ "this is medium long", "verily, can you see?" ], @@ -205,8 +207,8 @@ def test_jsonencode(): [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26 ], - [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, - 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26 + [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27 ], [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30 @@ -221,9 +223,9 @@ def test_jsonencode(): 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51 ], - [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, - 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, - 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51 + [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52 ] ], "layer3_6": [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 ], @@ -244,4 +246,9 @@ def test_jsonencode(): } } }""" - assert formatted == expected, f"Found:\n {formatted}" + assert ( + formatted == expected + ), f"Formatted does not match expected:\nExpected: {expected}\nFormatted: {formatted}" + assert ( + retrieved == data + ), f"Data before and after encoding do not match:\nBefore: {data}\nFormatted: {formatted}" From 0efea75272b2e44731005c580183afd7a21a06e9 Mon Sep 17 00:00:00 2001 From: Nick Smith Date: Fri, 26 Mar 2021 17:10:35 -0500 Subject: [PATCH 02/13] CI: include same MSVC workaround for wheel building --- .github/workflows/wheels.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 0c0e0266..c7611ed8 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -35,6 +35,13 @@ jobs: with: submodules: recursive + # workaround for MSVC, can be removed when scikit-build includes https://github.com/scikit-build/scikit-build/pull/526 + - name: Prepare compiler environment for Windows + if: runner.os == 'Windows' + uses: ilammy/msvc-dev-cmd@v1 + with: + arch: x64 + - uses: joerick/cibuildwheel@v1.10.0 env: CIBW_SKIP: cp27* From 1c1b8f890aca98ae102f9e38f8cd18a51740b606 Mon Sep 17 00:00:00 2001 From: Nick Smith Date: Fri, 26 Mar 2021 18:20:13 -0500 Subject: [PATCH 03/13] CI: dedicated windows bubild workflow for x64 and win32 --- .github/workflows/wheels.yml | 42 +++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index c7611ed8..d8551307 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -28,26 +28,52 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-18.04, windows-latest, macos-latest] + os: [ubuntu-18.04, macos-latest] steps: - uses: actions/checkout@v2 with: submodules: recursive - # workaround for MSVC, can be removed when scikit-build includes https://github.com/scikit-build/scikit-build/pull/526 - - name: Prepare compiler environment for Windows - if: runner.os == 'Windows' - uses: ilammy/msvc-dev-cmd@v1 + - uses: joerick/cibuildwheel@v1.10.0 + env: + CIBW_SKIP: cp27* + CIBW_TEST_EXTRAS: test + CIBW_TEST_COMMAND: pytest {project}/tests + MACOSX_DEPLOYMENT_TARGET: 10.14 + + - name: Upload wheels + uses: actions/upload-artifact@v2 + with: + path: wheelhouse/*.whl + + # workaround for MSVC, can be removed when scikit-build includes https://github.com/scikit-build/scikit-build/pull/526 + build_windows_wheels: + name: Wheels for Windows + runs-on: windows-latest + + steps: + - uses: actions/checkout@v2 + with: + submodules: recursive + + - uses: ilammy/msvc-dev-cmd@v1 + + - uses: joerick/cibuildwheel@v1.10.0 + env: + CIBW_BUILD: cp3*-win32 + CIBW_TEST_EXTRAS: test + CIBW_TEST_COMMAND: pytest {project}/tests + + - uses: ilammy/msvc-dev-cmd@v1 with: arch: x64 - uses: joerick/cibuildwheel@v1.10.0 env: - CIBW_SKIP: cp27* + CIBW_BUILD: cp3*-win_amd64 CIBW_TEST_EXTRAS: test CIBW_TEST_COMMAND: pytest {project}/tests - MACOSX_DEPLOYMENT_TARGET: 10.14 - name: Upload wheels uses: actions/upload-artifact@v2 @@ -55,7 +81,7 @@ jobs: path: wheelhouse/*.whl upload_all: - needs: [build_wheels, make_sdist] + needs: [build_wheels, build_windows_wheels, make_sdist] runs-on: ubuntu-latest if: github.event_name == 'release' && github.event.action == 'published' From 904f69a53646c4cfc99cfdfb61a45e352c7d1998 Mon Sep 17 00:00:00 2001 From: Nick Smith Date: Fri, 26 Mar 2021 18:27:33 -0500 Subject: [PATCH 04/13] CI: specify explicitly win32 --- .github/workflows/wheels.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index d8551307..d98cd0f3 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -58,6 +58,8 @@ jobs: submodules: recursive - uses: ilammy/msvc-dev-cmd@v1 + with: + arch: x86 - uses: joerick/cibuildwheel@v1.10.0 env: From 927b18e867b57b320824285563afb78d0e37691e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 7 Apr 2021 08:50:59 -0500 Subject: [PATCH 05/13] Bump pre-commit/action from v2.0.0 to v2.0.2 (#67) Bumps [pre-commit/action](https://github.com/pre-commit/action) from v2.0.0 to v2.0.2. - [Release notes](https://github.com/pre-commit/action/releases) - [Commits](https://github.com/pre-commit/action/compare/v2.0.0...9cf68dc1ace5504cd0e05b9f3df32e6a0822ad89) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 09c34548..d9420c6b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,7 @@ jobs: with: submodules: recursive - uses: actions/setup-python@v2 - - uses: pre-commit/action@v2.0.0 + - uses: pre-commit/action@v2.0.2 with: extra_args: --hook-stage manual --all-files From ed55c071df050dccc086efeb82c36a375762f148 Mon Sep 17 00:00:00 2001 From: Nicholas Smith Date: Thu, 8 Apr 2021 10:56:52 -0500 Subject: [PATCH 06/13] Update links in README to cms-nanoAOD org --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 4e6ef195..a2a4864e 100644 --- a/README.md +++ b/README.md @@ -117,18 +117,18 @@ Some examples can be found in `data/conversion.py`. The `tests/` directory may a ## Developing See CONTRIBUTING.md -[actions-badge]: https://github.com/nsmith-/correctionlib/workflows/CI/badge.svg -[actions-link]: https://github.com/nsmith-/correctionlib/actions +[actions-badge]: https://github.com/cms-nanoAOD/correctionlib/workflows/CI/badge.svg +[actions-link]: https://github.com/cms-nanoAOD/correctionlib/actions [black-badge]: https://img.shields.io/badge/code%20style-black-000000.svg [black-link]: https://github.com/psf/black [conda-badge]: https://img.shields.io/conda/vn/conda-forge/correctionlib [conda-link]: https://github.com/conda-forge/correctionlib-feedstock [github-discussions-badge]: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github -[github-discussions-link]: https://github.com/nsmith-/correctionlib/discussions -[gitter-badge]: https://badges.gitter.im/https://github.com/nsmith-/correctionlib/community.svg -[gitter-link]: https://gitter.im/https://github.com/nsmith-/correctionlib/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge +[github-discussions-link]: https://github.com/cms-nanoAOD/correctionlib/discussions +[gitter-badge]: https://badges.gitter.im/https://github.com/cms-nanoAOD/correctionlib/community.svg +[gitter-link]: https://gitter.im/https://github.com/cms-nanoAOD/correctionlib/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge [pypi-link]: https://pypi.org/project/correctionlib/ [pypi-platforms]: https://img.shields.io/pypi/pyversions/correctionlib [pypi-version]: https://badge.fury.io/py/correctionlib.svg -[rtd-badge]: https://github.com/nsmith-/correctionlib/actions/workflows/docs.yml/badge.svg -[rtd-link]: https://nsmith-.github.io/correctionlib/ +[rtd-badge]: https://github.com/cms-nanoAOD/correctionlib/actions/workflows/docs.yml/badge.svg +[rtd-link]: https://cms-nanoAOD.github.io/correctionlib/ From 2d0f3d5e11d802fbcfbb3e36d3ff1b9501cab57c Mon Sep 17 00:00:00 2001 From: Nicholas Smith Date: Tue, 13 Apr 2021 11:49:46 -0500 Subject: [PATCH 07/13] Command line interface (#68) * Create a highlevel wrapper combining model and evaluator together Also set up a CLI framework * Actually add CLI * Configurable output * Use mapping from typing * Sort values * Import windows DLL before highlevel --- .gitignore | 4 + setup.cfg | 10 +- src/correctionlib/__init__.py | 10 +- src/correctionlib/_core/__init__.pyi | 23 +++++ src/correctionlib/cli.py | 99 ++++++++++++++++++++ src/correctionlib/highlevel.py | 105 +++++++++++++++++++++ src/correctionlib/schemav2.py | 134 ++++++++++++++++++++++++++- tests/test_core.py | 2 +- tests/test_highlevel.py | 30 ++++++ 9 files changed, 410 insertions(+), 7 deletions(-) create mode 100644 src/correctionlib/_core/__init__.pyi create mode 100644 src/correctionlib/cli.py create mode 100644 src/correctionlib/highlevel.py create mode 100644 tests/test_highlevel.py diff --git a/.gitignore b/.gitignore index b476f2a5..7433f651 100644 --- a/.gitignore +++ b/.gitignore @@ -140,3 +140,7 @@ cython_debug/ # setuptools_scm src/*/version.py + +_skbuild +src/correctionlib/cmake +src/correctionlib/include diff --git a/setup.cfg b/setup.cfg index fafa0f94..cd0dffe1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ maintainer_email = nick.smith@cern.ch description = A generic correction library long_description = file: README.md long_description_content_type = text/markdown -url = https://github.com/nsmith-/correctionlib +url = https://github.com/cms-nanoAOD/correctionlib license = BSD 3-Clause License # keywords = platforms = @@ -34,6 +34,7 @@ install_requires = typing-extensions;python_version<"3.8" dataclasses;python_version<"3.7" pydantic >=1.7.3 + rich [options.extras_require] test = @@ -54,6 +55,10 @@ convert = uproot >=4.0.4 requests +[options.entry_points] +console_scripts = + correction = correctionlib.cli:main + [tool:pytest] addopts = -rs -s -Wd testpaths = @@ -108,3 +113,6 @@ strict_equality = True [mypy-numpy] ignore_missing_imports = True + +[mypy-uproot] +ignore_missing_imports = True diff --git a/src/correctionlib/__init__.py b/src/correctionlib/__init__.py index 7656d8db..f4a781d6 100644 --- a/src/correctionlib/__init__.py +++ b/src/correctionlib/__init__.py @@ -1,7 +1,3 @@ -from .version import version as __version__ - -__all__ = ("__version__",) - import sys if sys.platform.startswith("win32"): @@ -9,3 +5,9 @@ import os.path ctypes.CDLL(os.path.join(os.path.dirname(__file__), "lib", "correctionlib.dll")) + + +from .highlevel import Correction, CorrectionSet +from .version import version as __version__ + +__all__ = ("__version__", "CorrectionSet", "Correction") diff --git a/src/correctionlib/_core/__init__.pyi b/src/correctionlib/_core/__init__.pyi new file mode 100644 index 00000000..d7d00b00 --- /dev/null +++ b/src/correctionlib/_core/__init__.pyi @@ -0,0 +1,23 @@ +from typing import Iterator, Type, TypeVar, Union + +class Correction: + @property + def name(self) -> str: ... + @property + def description(self) -> str: ... + @property + def version(self) -> int: ... + def evaluate(self, *args: Union[str, int, float]) -> float: ... + +T = TypeVar("T", bound="CorrectionSet") + +class CorrectionSet: + @classmethod + def from_file(cls: Type[T], filename: str) -> T: ... + @classmethod + def from_string(cls: Type[T], data: str) -> T: ... + @property + def schema_version(self) -> int: ... + def __getitem__(self, key: str) -> Correction: ... + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[str]: ... diff --git a/src/correctionlib/cli.py b/src/correctionlib/cli.py new file mode 100644 index 00000000..c01e3424 --- /dev/null +++ b/src/correctionlib/cli.py @@ -0,0 +1,99 @@ +"""Command-line interface to correctionlib + +""" +import argparse + +from rich.console import Console + +from correctionlib.highlevel import model_auto, open_auto + + +def validate(console: Console, args: argparse.Namespace) -> int: + """Check if all files are valid""" + retcode = 0 + for file in args.files: + try: + if not args.quiet: + console.rule(f"[blue]Validating file {file}") + cset = model_auto(open_auto(file)) + if args.version and cset.schema_version != args.version: + raise ValueError( + f"Schema version {cset.schema_version} does not match the required version {args.version}" + ) + except Exception as ex: + if not args.quiet: + console.print(str(ex)) + retcode = 1 + if args.failfast: + break + else: + if not args.quiet: + console.print("[green]All OK :heavy_check_mark:") + return retcode + + +def setup_validate(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser("validate", help=validate.__doc__) + parser.set_defaults(command=validate) + parser.add_argument( + "--quiet", + "-q", + action="store_true", + help="Suppress error printout, only produce a returncode", + ) + parser.add_argument( + "--failfast", + "-f", + action="store_true", + help="Fail on first invalid file", + ) + parser.add_argument( + "--version", + "-v", + type=int, + default=None, + help="Validate against specific schema version", + ) + parser.add_argument("files", nargs="+", metavar="FILE") + + +def summary(console: Console, args: argparse.Namespace) -> int: + for file in args.files: + console.rule(f"[blue]Corrections in file {file}") + cset = model_auto(open_auto(file)) + console.print(cset) + return 0 + + +def setup_summary(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser( + "summary", help="Print a summmary of the corrections" + ) + parser.set_defaults(command=summary) + parser.add_argument("files", nargs="+", metavar="FILE") + + +def main() -> int: + parser = argparse.ArgumentParser(prog="correction", description=__doc__) + parser.add_argument( + "--width", + type=int, + default=100, + help="Rich output width", + ) + parser.add_argument("--html", type=str, help="Save HTML output to a file") + subparsers = parser.add_subparsers() + setup_validate(subparsers) + setup_summary(subparsers) + args = parser.parse_args() + + console = Console(width=args.width, record=True) + # py3.7: subparsers has required=True option + if hasattr(args, "command"): + retcode: int = args.command(console, args) + if args.html: + console.save_html(args.html) + return retcode + + parser.parse_args(["-h"]) + return 0 diff --git a/src/correctionlib/highlevel.py b/src/correctionlib/highlevel.py new file mode 100644 index 00000000..da414ba2 --- /dev/null +++ b/src/correctionlib/highlevel.py @@ -0,0 +1,105 @@ +"""High-level correctionlib objects + +""" +import json +from numbers import Integral +from typing import Any, Iterator, Mapping, Optional, Union + +import correctionlib._core +import correctionlib.version + + +def open_auto(filename: str) -> Any: + """Open a file and return a deserialized json object""" + if filename.endswith(".json.gz"): + import gzip + + with gzip.open(filename, "r") as gzfile: + return json.load(gzfile) + elif filename.endswith(".json"): + with open(filename) as file: + return json.load(file) + raise ValueError(f"{filename}: unrecognized file format, expected .json, .json.gz") + + +def model_auto(data: Any) -> Any: + """Read schema version from json object and construct appropriate model""" + if not isinstance(data, dict): + raise ValueError("CorrectionSet is not a dictionary!") + version = data.get("schema_version", None) + if version is None: + raise ValueError("CorrectionSet has no schema version!") + if not isinstance(version, Integral): + raise ValueError(f"CorrectionSet schema version ({version}) is not an integer!") + if version == 1: + import correctionlib.schemav1 + + return correctionlib.schemav1.CorrectionSet.parse_obj(data) + elif version == 2: + import correctionlib.schemav2 + + return correctionlib.schemav2.CorrectionSet.parse_obj(data) + raise ValueError(f"Unknown CorrectionSet schema version ({version})") + + +class Correction: + def __init__(self, base: correctionlib._core.Correction): + self._base = base + + @property + def name(self) -> str: + return self._base.name + + @property + def description(self) -> str: + return self._base.description + + @property + def version(self) -> int: + return self._base.version + + def evaluate(self, *args: Union[str, int, float]) -> float: + return self._base.evaluate(*args) + + +class CorrectionSet(Mapping[str, Correction]): + def __init__(self, model: Any, *, schema_version: Optional[int] = None): + if schema_version is None: + this_version = correctionlib.version.version_tuple[0] + if model.schema_version < this_version: + # TODO: upgrade schema automatically + raise NotImplementedError( + "Cannot read CorrectionSet models older than {this_version}" + ) + elif schema_version != model.schema_version: + raise ValueError( + f"CorrectionSet schema version ({model.schema_version}) differs from desired version ({schema_version})" + ) + self._model = model + self._base = correctionlib._core.CorrectionSet.from_string(model.json()) + + @classmethod + def from_file( + cls, filename: str, *, schema_version: Optional[int] = None + ) -> "CorrectionSet": + return cls(model_auto(open_auto(filename)), schema_version=schema_version) + + @classmethod + def from_string( + cls, data: str, *, schema_version: Optional[int] = None + ) -> "CorrectionSet": + return cls(model_auto(json.loads(data)), schema_version=schema_version) + + @property + def schema_version(self) -> int: + return self._base.schema_version + + def __getitem__(self, key: str) -> Correction: + corr = self._base.__getitem__(key) + return Correction(corr) + + def __len__(self) -> int: + return len(self._base) + + def __iter__(self) -> Iterator[str]: + return iter(self._base) diff --git a/src/correctionlib/schemav2.py b/src/correctionlib/schemav2.py index 8ab4e00e..921d94eb 100644 --- a/src/correctionlib/schemav2.py +++ b/src/correctionlib/schemav2.py @@ -1,6 +1,11 @@ -from typing import Any, List, Optional, Union +from collections import defaultdict +from typing import Any, Dict, List, Optional, Set, Tuple, Union from pydantic import BaseModel, Field, StrictInt, StrictStr, validator +from rich.columns import Columns +from rich.console import Console, ConsoleOptions, RenderResult +from rich.panel import Panel +from rich.tree import Tree try: from typing import Literal # type: ignore @@ -16,6 +21,16 @@ class Config: extra = "forbid" +class _SummaryInfo: + def __init__(self) -> None: + self.values: Set[Union[str, int]] = set() + self.default: bool = False + self.overflow: bool = True + self.transform: bool = False + self.min: float = float("inf") + self.max: float = float("-inf") + + class Variable(Model): """An input or output variable""" @@ -27,6 +42,11 @@ class Variable(Model): description="A nice description of what this variable means" ) + def __rich__(self) -> str: + msg = f"[bold]{self.name}[/bold] ({self.type})\n" + msg += self.description or "[i]No description[/i]" + return msg + # py3.7+: ForwardRef can be used instead of strings Content = Union[ @@ -47,6 +67,11 @@ class Formula(Model): description="Parameters, if the parser supports them (e.g. [0] for TFormula)" ) + def summarize( + self, nodecount: Dict[str, int], inputstats: Dict[str, _SummaryInfo] + ) -> None: + nodecount["Formula"] += 1 + class FormulaRef(Model): """A reference to one of the Correction generic_formula items, with specific parameters""" @@ -59,6 +84,11 @@ class FormulaRef(Model): description="Same interpretation as Formula.parameters" ) + def summarize( + self, nodecount: Dict[str, int], inputstats: Dict[str, _SummaryInfo] + ) -> None: + nodecount["FormulaRef"] += 1 + class Transform(Model): """A node that rewrites one real or integer input according to a rule as given by a content node @@ -76,6 +106,14 @@ class Transform(Model): description="A subtree that will be evaluated with transformed values" ) + def summarize( + self, nodecount: Dict[str, int], inputstats: Dict[str, _SummaryInfo] + ) -> None: + nodecount["Transform"] += 1 + inputstats[self.input].transform = True + if not isinstance(self.content, float): + self.content.summarize(nodecount, inputstats) + class Binning(Model): """1-dimensional binning in an input variable""" @@ -109,6 +147,19 @@ def validate_content(cls, content: List[Content], values: Any) -> List[Content]: ) return content + def summarize( + self, nodecount: Dict[str, int], inputstats: Dict[str, _SummaryInfo] + ) -> None: + nodecount["Binning"] += 1 + inputstats[self.input].overflow &= self.flow != "error" + inputstats[self.input].min = min(inputstats[self.input].min, self.edges[0]) + inputstats[self.input].max = max(inputstats[self.input].max, self.edges[-1]) + for item in self.content: + if not isinstance(item, float): + item.summarize(nodecount, inputstats) + if not isinstance(self.flow, (float, str)): + self.flow.summarize(nodecount, inputstats) + class MultiBinning(Model): """N-dimensional rectangular binning""" @@ -151,6 +202,20 @@ def validate_content(cls, content: List[Content], values: Any) -> List[Content]: ) return content + def summarize( + self, nodecount: Dict[str, int], inputstats: Dict[str, _SummaryInfo] + ) -> None: + nodecount["MultiBinning"] += 1 + for input, edges in zip(self.inputs, self.edges): + inputstats[input].overflow &= self.flow != "error" + inputstats[input].min = min(inputstats[input].min, edges[0]) + inputstats[input].max = max(inputstats[input].max, edges[-1]) + for item in self.content: + if not isinstance(item, float): + item.summarize(nodecount, inputstats) + if not isinstance(self.flow, (float, str)): + self.flow.summarize(nodecount, inputstats) + class CategoryItem(Model): """A key-value pair @@ -186,6 +251,18 @@ def validate_content(cls, content: List[CategoryItem]) -> List[CategoryItem]: raise ValueError("Duplicate keys detected in Category node") return content + def summarize( + self, nodecount: Dict[str, int], inputstats: Dict[str, _SummaryInfo] + ) -> None: + nodecount["Category"] += 1 + inputstats[self.input].values |= {item.key for item in self.content} + inputstats[self.input].default |= self.default is not None + for item in self.content: + if not isinstance(item.value, float): + item.value.summarize(nodecount, inputstats) + if self.default and not isinstance(self.default, float): + self.default.summarize(nodecount, inputstats) + Transform.update_forward_refs() Binning.update_forward_refs() @@ -225,11 +302,66 @@ def validate_output(cls, output: Variable) -> Variable: ) return output + def summary(self) -> Tuple[Dict[str, int], Dict[str, _SummaryInfo]]: + nodecount: Dict[str, int] = defaultdict(int) + inputstats = {var.name: _SummaryInfo() for var in self.inputs} + if not isinstance(self.data, float): + self.data.summarize(nodecount, inputstats) + return nodecount, inputstats + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + yield f":chart_with_upwards_trend: [bold]{self.name}[/bold] (v{self.version})" + yield self.description or "[i]No description[/i]" + nodecount, inputstats = self.summary() + yield "Node counts: " + ", ".join( + f"[b]{key}[/b]: {val}" for key, val in nodecount.items() + ) + + def fmt_input(var: Variable, stats: _SummaryInfo) -> str: + out = var.__rich__() + if var.type == "real": + out += f"\nRange: [{stats.min}, {stats.max})" + if stats.overflow: + out += ", overflow ok" + if stats.transform: + out += "\n[bold red]has transform[/bold red]" + else: + out += "\nValues: " + ", ".join(str(v) for v in sorted(stats.values)) + if stats.default: + out += "\n[bold green]has default[/bold green]" + return out + + inputs = ( + Panel( + fmt_input(var, inputstats[var.name]), + title=":arrow_forward: input", + ) + for var in self.inputs + ) + yield Columns(inputs) + yield Panel( + self.output.__rich__(), + title=":arrow_backward: output", + expand=False, + ) + class CorrectionSet(Model): schema_version: Literal[VERSION] = Field(description="The overall schema version") corrections: List[Correction] + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + tree = Tree( + f":open_file_folder: CorrectionSet ([i]schema v{self.schema_version}[/i])" + ) + for corr in self.corrections: + tree.add(corr) + yield tree + if __name__ == "__main__": import os diff --git a/tests/test_core.py b/tests/test_core.py index c69b3448..a4b96c9f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -16,7 +16,7 @@ def wrap(*corrs): return core.CorrectionSet.from_string(cset.json()) -def test_evaluator_v1(): +def test_evaluator(): with pytest.raises(RuntimeError): cset = core.CorrectionSet.from_string("{") diff --git a/tests/test_highlevel.py b/tests/test_highlevel.py new file mode 100644 index 00000000..55fea92b --- /dev/null +++ b/tests/test_highlevel.py @@ -0,0 +1,30 @@ +import pytest + +import correctionlib +from correctionlib import schemav2 as model + + +def test_highlevel(): + cset = correctionlib.CorrectionSet( + model.CorrectionSet( + schema_version=model.VERSION, + corrections=[ + model.Correction( + name="test corr", + version=2, + inputs=[], + output=model.Variable(name="a scale", type="real"), + data=1.234, + ) + ], + ) + ) + assert set(cset) == {"test corr"} + sf = cset["test corr"] + assert sf.version == 2 + assert sf.description == "" + + with pytest.raises(RuntimeError): + sf.evaluate(0, 1.2, 35.0, 0.01) + + assert sf.evaluate() == 1.234 From 446efb29cff39c3e8d9a50243e7a5d557a86c0c1 Mon Sep 17 00:00:00 2001 From: Nicholas Smith Date: Mon, 26 Apr 2021 09:31:08 -0500 Subject: [PATCH 08/13] Merge tool (#71) * Allow CLI execution through `python -m correctionlib.cli` * Add a console merge tool --- src/correctionlib/cli.py | 49 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/src/correctionlib/cli.py b/src/correctionlib/cli.py index c01e3424..07f2ee93 100644 --- a/src/correctionlib/cli.py +++ b/src/correctionlib/cli.py @@ -2,6 +2,7 @@ """ import argparse +import sys from rich.console import Console @@ -73,6 +74,49 @@ def setup_summary(subparsers: argparse._SubParsersAction) -> None: parser.add_argument("files", nargs="+", metavar="FILE") +def merge(console: Console, args: argparse.Namespace) -> int: + cset = model_auto(open_auto(args.files[0])) + for file in args.files[1:]: + cset2 = model_auto(open_auto(file)) + if cset2.schema_version != cset.schema_version: + console.print("[red]Mixed schema versions detected") + return 1 + for corr2 in cset2.corrections: + if any(corr.name == corr2.name for corr in cset.corrections): + console.print( + f"[red]Correction '{corr2.name}' from {file} is a duplicate" + ) + return 1 + cset.corrections.append(corr2) + if args.format == "compact": + sys.stdout.write(cset.json()) + elif args.format == "indented": + sys.stdout.write(cset.json(indent=4) + "\n") + elif args.format == "pretty": + from correctionlib.JSONEncoder import dumps + + sys.stdout.write(dumps(cset) + "\n") + else: + return 1 + return 0 + + +def setup_merge(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser( + "merge", help="Merge one or more correction files and print to stdout" + ) + parser.set_defaults(command=merge) + parser.add_argument( + "-f", + "--format", + type=str, + help="JSON output formatting (default: %(default)s)", + choices=("compact", "indented", "pretty"), + default="compact", + ) + parser.add_argument("files", nargs="+", metavar="FILE") + + def main() -> int: parser = argparse.ArgumentParser(prog="correction", description=__doc__) parser.add_argument( @@ -85,6 +129,7 @@ def main() -> int: subparsers = parser.add_subparsers() setup_validate(subparsers) setup_summary(subparsers) + setup_merge(subparsers) args = parser.parse_args() console = Console(width=args.width, record=True) @@ -97,3 +142,7 @@ def main() -> int: parser.parse_args(["-h"]) return 0 + + +if __name__ == "__main__": + exit(main()) From 403ba236af1462ee0330efd9940e3642e990c7c4 Mon Sep 17 00:00:00 2001 From: Nicholas Smith Date: Mon, 26 Apr 2021 21:13:20 -0500 Subject: [PATCH 09/13] Vectorized evaluation (#72) * Implement vectorized evaluation for flat numpy arrays * Expose vectorized evaluation in highlevel * ipython key completions --- src/correction.cc | 7 ++- src/correctionlib/_core/__init__.pyi | 3 ++ src/correctionlib/highlevel.py | 26 ++++++++-- src/python.cc | 53 +++++++++++++++++++ tests/test_core_vect.py | 77 ++++++++++++++++++++++++++++ tests/test_highlevel.py | 16 +++++- 6 files changed, 176 insertions(+), 6 deletions(-) create mode 100644 tests/test_core_vect.py diff --git a/src/correction.cc b/src/correction.cc index f1be59d1..bf09a775 100644 --- a/src/correction.cc +++ b/src/correction.cc @@ -109,7 +109,12 @@ Formula::Formula(const rapidjson::Value& json, const Correction& context, bool g std::vector variableIdx; for (const auto& item : json["variables"].GetArray()) { - variableIdx.push_back(context.input_index(item.GetString())); + auto idx = context.input_index(item.GetString()); + if ( context.inputs()[idx].type() != Variable::VarType::real ) { + throw std::runtime_error("Formulas only accept real-valued inputs, got type " + + context.inputs()[idx].typeStr() + " for variable " + context.inputs()[idx].name()); + } + variableIdx.push_back(idx); } std::vector params; diff --git a/src/correctionlib/_core/__init__.pyi b/src/correctionlib/_core/__init__.pyi index d7d00b00..b34090fa 100644 --- a/src/correctionlib/_core/__init__.pyi +++ b/src/correctionlib/_core/__init__.pyi @@ -1,5 +1,7 @@ from typing import Iterator, Type, TypeVar, Union +import numpy + class Correction: @property def name(self) -> str: ... @@ -8,6 +10,7 @@ class Correction: @property def version(self) -> int: ... def evaluate(self, *args: Union[str, int, float]) -> float: ... + def evalv(self, *args: Union[numpy.ndarray, str, int, float]) -> numpy.ndarray: ... T = TypeVar("T", bound="CorrectionSet") diff --git a/src/correctionlib/highlevel.py b/src/correctionlib/highlevel.py index da414ba2..a9624779 100644 --- a/src/correctionlib/highlevel.py +++ b/src/correctionlib/highlevel.py @@ -3,7 +3,9 @@ """ import json from numbers import Integral -from typing import Any, Iterator, Mapping, Optional, Union +from typing import Any, Iterator, List, Mapping, Optional, Union + +import numpy import correctionlib._core import correctionlib.version @@ -58,8 +60,23 @@ def description(self) -> str: def version(self) -> int: return self._base.version - def evaluate(self, *args: Union[str, int, float]) -> float: - return self._base.evaluate(*args) + def evaluate( + self, *args: Union[numpy.ndarray, str, int, float] + ) -> Union[float, numpy.ndarray]: + # TODO: create a ufunc with numpy.vectorize in constructor? + vargs = [arg for arg in args if isinstance(arg, numpy.ndarray)] + if vargs: + bargs = numpy.broadcast_arrays(*vargs) + oshape = bargs[0].shape + bargs = (arg.flatten() for arg in bargs) + out = self._base.evalv( + *( + next(bargs) if isinstance(arg, numpy.ndarray) else arg + for arg in args + ) + ) + return out.reshape(oshape) + return self._base.evaluate(*args) # type: ignore class CorrectionSet(Mapping[str, Correction]): @@ -90,6 +107,9 @@ def from_string( ) -> "CorrectionSet": return cls(model_auto(json.loads(data)), schema_version=schema_version) + def _ipython_key_completions_(self) -> List[str]: + return list(self.keys()) + @property def schema_version(self) -> int: return self._base.schema_version diff --git a/src/python.cc b/src/python.cc index e274140f..e919253d 100644 --- a/src/python.cc +++ b/src/python.cc @@ -1,3 +1,4 @@ +#include #include #include #include "correction.h" @@ -14,6 +15,58 @@ PYBIND11_MODULE(_core, m) { .def_property_readonly("version", &Correction::version) .def("evaluate", [](Correction& c, py::args args) { return c.evaluate(py::cast>(args)); + }) + .def("evalv", [](Correction& c, py::args args) { + std::vector inputs; + inputs.reserve(py::len(args)); + std::vector> vargs; + if ( py::len(args) != c.inputs().size() ) { + throw std::invalid_argument("Incorrect number of inputs (got " + std::to_string(py::len(args)) + + ", expected " + std::to_string(c.inputs().size()) + ")"); + } + for (size_t i=0; i < py::len(args); ++i) { + if ( py::isinstance(args[i]) ) { + if ( c.inputs()[i].type() == Variable::VarType::integer ) { + vargs.emplace_back(i, py::cast>(args[i]).request()); + inputs.emplace_back(0); + } + else if ( c.inputs()[i].type() == Variable::VarType::real ) { + vargs.emplace_back(i, py::cast>(args[i]).request()); + inputs.emplace_back(0.0); + } + else { + throw std::invalid_argument("Array arguments only allowed for integer and real input types"); + } + + if ( vargs.back().second.ndim != 1 ) { + throw std::invalid_argument("Array arguments with dimension greater " + "than one are not supported (argument at position " + std::to_string(i) + ")"); + } + if ( vargs.back().second.size != vargs.front().second.size ) { + throw std::invalid_argument("Array arguments must all have the same size" + "(argument at position " + std::to_string(i) + " is length " + + std::to_string(vargs.back().second.size) + ")"); + } + } + else { + inputs.push_back(py::cast(args[i])); + } + } + auto output = py::array_t((vargs.size() > 0) ? vargs.front().second.size : 1); + py::buffer_info outbuffer = output.request(); + double * outptr = static_cast(outbuffer.ptr); + for (size_t i=0; i < outbuffer.shape[0]; ++i) { + for (const auto& varg : vargs) { + if ( std::holds_alternative(inputs[varg.first]) ) { + inputs[varg.first] = static_cast(varg.second.ptr)[i]; + } + else if ( std::holds_alternative(inputs[varg.first]) ) { + inputs[varg.first] = static_cast(varg.second.ptr)[i]; + } + } + outptr[i] = c.evaluate(inputs); + } + return output; }); py::class_(m, "CorrectionSet") diff --git a/tests/test_core_vect.py b/tests/test_core_vect.py new file mode 100644 index 00000000..4e704851 --- /dev/null +++ b/tests/test_core_vect.py @@ -0,0 +1,77 @@ +import numpy +import pytest + +import correctionlib._core as core +from correctionlib import schemav2 as schema + + +def wrap(*corrs): + cset = schema.CorrectionSet( + schema_version=schema.VERSION, + corrections=list(corrs), + ) + return core.CorrectionSet.from_string(cset.json()) + + +def test_core_vectorized(): + cset = wrap( + schema.Correction( + name="test", + version=1, + inputs=[ + schema.Variable(name="a", type="real"), + schema.Variable(name="b", type="int"), + schema.Variable(name="c", type="string"), + ], + output=schema.Variable(name="a scale", type="real"), + data={ + "nodetype": "category", + "input": "b", + "content": [ + { + "key": 1, + "value": { + "nodetype": "formula", + "expression": "x", + "parser": "TFormula", + "variables": ["a"], + }, + } + ], + "default": -99.0, + }, + ) + ) + corr = cset["test"] + + assert corr.evaluate(0.3, 1, "") == 0.3 + assert corr.evalv(0.3, 1, "") == 0.3 + numpy.testing.assert_array_equal( + corr.evalv(numpy.full(10, 0.3), 1, ""), + numpy.full(10, 0.3), + ) + numpy.testing.assert_array_equal( + corr.evalv(0.3, numpy.full(10, 1), ""), + numpy.full(10, 0.3), + ) + numpy.testing.assert_array_equal( + corr.evalv(numpy.full(10, 0.3), numpy.full(10, 1), ""), + numpy.full(10, 0.3), + ) + with pytest.raises(ValueError): + corr.evalv(numpy.full(5, 0.3), numpy.full(10, 1), "") + with pytest.raises(ValueError): + corr.evalv(numpy.full((10, 2), 0.3), 1, "") + with pytest.raises(ValueError): + corr.evalv(0.3) + with pytest.raises(ValueError): + corr.evalv(0.3, 1, 1, 1) + with pytest.raises(ValueError): + corr.evalv(0.3, 1, numpy.full(10, "asdf")) + + a = numpy.linspace(-3, 3, 100) + b = numpy.arange(100) % 3 + numpy.testing.assert_array_equal( + corr.evalv(a, b, ""), + numpy.where(b == 1, a, -99.0), + ) diff --git a/tests/test_highlevel.py b/tests/test_highlevel.py index 55fea92b..c66476b5 100644 --- a/tests/test_highlevel.py +++ b/tests/test_highlevel.py @@ -1,3 +1,4 @@ +import numpy import pytest import correctionlib @@ -12,7 +13,10 @@ def test_highlevel(): model.Correction( name="test corr", version=2, - inputs=[], + inputs=[ + model.Variable(name="a", type="real"), + model.Variable(name="b", type="real"), + ], output=model.Variable(name="a scale", type="real"), data=1.234, ) @@ -27,4 +31,12 @@ def test_highlevel(): with pytest.raises(RuntimeError): sf.evaluate(0, 1.2, 35.0, 0.01) - assert sf.evaluate() == 1.234 + assert sf.evaluate(1.0, 1.0) == 1.234 + numpy.testing.assert_array_equal( + sf.evaluate(numpy.ones((3, 4)), 1.0), + numpy.full((3, 4), 1.234), + ) + numpy.testing.assert_array_equal( + sf.evaluate(numpy.ones((3, 4)), numpy.ones(4)), + numpy.full((3, 4), 1.234), + ) From 6c4d761ef0458ce651222341e121294073d5e8ef Mon Sep 17 00:00:00 2001 From: Nicholas Smith Date: Tue, 27 Apr 2021 15:49:59 -0500 Subject: [PATCH 10/13] Make highlevel correction objects pickleable (#73) Closes #69 --- src/correctionlib/highlevel.py | 31 ++++++++++++++++++++++++------- tests/test_highlevel.py | 5 +++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/src/correctionlib/highlevel.py b/src/correctionlib/highlevel.py index a9624779..f032b031 100644 --- a/src/correctionlib/highlevel.py +++ b/src/correctionlib/highlevel.py @@ -3,7 +3,7 @@ """ import json from numbers import Integral -from typing import Any, Iterator, List, Mapping, Optional, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Union import numpy @@ -16,7 +16,7 @@ def open_auto(filename: str) -> Any: if filename.endswith(".json.gz"): import gzip - with gzip.open(filename, "r") as gzfile: + with gzip.open(filename, "rt") as gzfile: return json.load(gzfile) elif filename.endswith(".json"): with open(filename) as file: @@ -45,12 +45,22 @@ def model_auto(data: Any) -> Any: class Correction: - def __init__(self, base: correctionlib._core.Correction): + def __init__(self, base: correctionlib._core.Correction, context: "CorrectionSet"): self._base = base + self._name = base.name + self._context = context + + def __getstate__(self) -> Dict[str, Any]: + return {"_context": self._context, "_name": self._name} + + def __setstate__(self, state: Dict[str, Any]) -> None: + self._context = state["_context"] + self._name = state["_name"] + self._base = self._context[self._name]._base @property def name(self) -> str: - return self._base.name + return self._name @property def description(self) -> str: @@ -86,14 +96,14 @@ def __init__(self, model: Any, *, schema_version: Optional[int] = None): if model.schema_version < this_version: # TODO: upgrade schema automatically raise NotImplementedError( - "Cannot read CorrectionSet models older than {this_version}" + f"Cannot read CorrectionSet models older than {this_version}" ) elif schema_version != model.schema_version: raise ValueError( f"CorrectionSet schema version ({model.schema_version}) differs from desired version ({schema_version})" ) self._model = model - self._base = correctionlib._core.CorrectionSet.from_string(model.json()) + self._base = correctionlib._core.CorrectionSet.from_string(self._model.json()) @classmethod def from_file( @@ -107,6 +117,13 @@ def from_string( ) -> "CorrectionSet": return cls(model_auto(json.loads(data)), schema_version=schema_version) + def __getstate__(self) -> Dict[str, Any]: + return {"_model": self._model} + + def __setstate__(self, state: Dict[str, Any]) -> None: + self._model = state["_model"] + self._base = correctionlib._core.CorrectionSet.from_string(self._model.json()) + def _ipython_key_completions_(self) -> List[str]: return list(self.keys()) @@ -116,7 +133,7 @@ def schema_version(self) -> int: def __getitem__(self, key: str) -> Correction: corr = self._base.__getitem__(key) - return Correction(corr) + return Correction(corr, self) def __len__(self) -> int: return len(self._base) diff --git a/tests/test_highlevel.py b/tests/test_highlevel.py index c66476b5..03937647 100644 --- a/tests/test_highlevel.py +++ b/tests/test_highlevel.py @@ -1,3 +1,5 @@ +import pickle + import numpy import pytest @@ -40,3 +42,6 @@ def test_highlevel(): sf.evaluate(numpy.ones((3, 4)), numpy.ones(4)), numpy.full((3, 4), 1.234), ) + + sf2 = pickle.loads(pickle.dumps(sf)) + assert sf2.evaluate(1.0, 1.0) == 1.234 From f6d13a6a9ae9f4439d63288ae31b6468ef8a14a7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 28 Apr 2021 10:00:30 -0500 Subject: [PATCH 11/13] Bump pre-commit/action from v2.0.2 to v2.0.3 (#70) Bumps [pre-commit/action](https://github.com/pre-commit/action) from v2.0.2 to v2.0.3. - [Release notes](https://github.com/pre-commit/action/releases) - [Commits](https://github.com/pre-commit/action/compare/v2.0.2...9b88afc9cd57fd75b655d5c71bd38146d07135fe) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d9420c6b..9e0d63b3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,7 @@ jobs: with: submodules: recursive - uses: actions/setup-python@v2 - - uses: pre-commit/action@v2.0.2 + - uses: pre-commit/action@v2.0.3 with: extra_args: --hook-stage manual --all-files From 4d69127ebff7c97a19fe52175dad0fd5866af683 Mon Sep 17 00:00:00 2001 From: Nicholas Smith Date: Mon, 3 May 2021 13:30:42 -0500 Subject: [PATCH 12/13] Some convenience utilities (#76) * Compilation and version info CLI utility * Quick-construct evaluator from schema, document highlevel * Need rpath for linux too * Catch unused vars * More accurate summary for formulas * Docs error * Add pyroot registration tool * Add link to cmake example --- README.md | 47 +++++++++++++++++---------------- docs/highlevel.rst | 11 ++++++++ docs/index.rst | 3 ++- src/correctionlib/__init__.py | 11 ++++++-- src/correctionlib/binding.py | 25 ++++++++++++++++++ src/correctionlib/cli.py | 48 ++++++++++++++++++++++++++++++++-- src/correctionlib/highlevel.py | 14 ++++++++++ src/correctionlib/schemav1.py | 2 +- src/correctionlib/schemav2.py | 23 ++++++++++++++-- 9 files changed, 154 insertions(+), 30 deletions(-) create mode 100644 docs/highlevel.rst create mode 100644 src/correctionlib/binding.py diff --git a/README.md b/README.md index a2a4864e..fe4d394e 100644 --- a/README.md +++ b/README.md @@ -70,38 +70,41 @@ The build process is based on setuptools, with CMake (through scikit-build) for the C++ evaluator and its python bindings module. Builds have been tested in Windows, OS X, and Linux, and python bindings can be compiled against both python2 and python3, as well as from within a CMSSW environment. The python bindings are distributed as a -pip-installable package. +pip-installable package. Note that CMSSW 11_2_X and above has ROOT accessible from python 3. -To build in an environment that has python 3, you can simply +To install in an environment that has python 3, you can simply ```bash -pip install correctionlib +python3 -m pip install correctionlib ``` (possibly with `--user`, or in a virtualenv, etc.) -Note that CMSSW 11_2_X and above has ROOT accessible from python 3. - -The C++ evaluator is part of the python package. If you are also using CMake you can depend on it by passing -`-Dcorrectionlib_DIR=$(python -c 'import pkg_resources; print(pkg_resources.resource_filename("correctionlib", "cmake"))')`. -The header and shared library can similarly be found as -```python -import pkg_resources -pkg_resources.resource_filename("correctionlib", "include/correction.h") -pkg_resources.resource_filename("correctionlib", "lib/libcorrectionlib.so") +If you wish to install the latest development version, +```bash +python3 -m pip install git+https://github.com/cms-nanoAOD/correctionlib.git ``` +should work. -In environments where no recent CMake is available, or if you want to build against python2 -(see below), you can build the C++ evaluator via: +The C++ evaluator library is distributed as part of the python package, and it can be +linked to directly without using python. If you are using CMake you can depend on it by including +the output of `correction config --cmake` in your cmake invocation. A complete cmake +example that builds a user C++ application against correctionlib and ROOT RDataFrame +can be [found here](https://gist.github.com/pieterdavid/a560e65658386d70a1720cb5afe4d3e9). + +For manual compilation, include and linking definitions can similarly be found via `correction config --cflags --ldflags`. +For example, the demo application can be compiled with: ```bash -git clone --recursive git@github.com:nsmith-/correctionlib.git -cd correctionlib -make -# demo C++ binding, main function at src/demo.cc -gunzip data/examples.json.gz -./demo data/examples.json +wget https://raw.githubusercontent.com/cms-nanoAOD/correctionlib/master/src/demo.cc +g++ $(correction config --cflags --ldflags --rpath) demo.cc -o demo ``` -Eventually a `correction-config` utility will be added to retrieve the header and linking flags. + +If the `correction` command-line utility is not on your path for some reason, it can also be invoked via `python -m correctionlib.cli`. To compile with python2 support, consider using python 3 :) If you considered that and still -want to use python2, follow the C++ build instructions and then call `make PYTHON=python2 correctionlib` to compile. +want to use python2, the following recipe may work: +```bash +git clone --recursive git@github.com:cms-nanoAOD/correctionlib.git +cd correctionlib +make PYTHON=python2 correctionlib +``` Inside CMSSW you should use `make PYTHON=python correctionlib` assuming `python` is the name of the scram tool you intend to link against. This will output a `correctionlib` directory that acts as a python package, and can be moved where needed. This package will only provide the `correctionlib._core` evaluator module, as the schema tools and high-level bindings are python3-only. diff --git a/docs/highlevel.rst b/docs/highlevel.rst new file mode 100644 index 00000000..0b044404 --- /dev/null +++ b/docs/highlevel.rst @@ -0,0 +1,11 @@ +correctionlib.highlevel +----------------------- +High-level interface to correction evaluator module. These objects +are also available directly through the ``correctionlib`` namespace. + +.. currentmodule:: correctionlib.highlevel +.. autosummary:: + :toctree: _generated + + Correction + CorrectionSet diff --git a/docs/index.rst b/docs/index.rst index 9feeb585..ed0cfe28 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -16,10 +16,11 @@ scalar inputs that produce a scalar output. :caption: Contents :glob: + highlevel schemav1 schemav2 - core convert + core Indices and tables diff --git a/src/correctionlib/__init__.py b/src/correctionlib/__init__.py index f4a781d6..95deeab0 100644 --- a/src/correctionlib/__init__.py +++ b/src/correctionlib/__init__.py @@ -4,10 +4,17 @@ import ctypes import os.path - ctypes.CDLL(os.path.join(os.path.dirname(__file__), "lib", "correctionlib.dll")) + import pkg_resources + ctypes.CDLL( + pkg_resources.resource_filename( + "correctionlib", os.path.join("lib", "correctionlib.dll") + ) + ) + +from .binding import register_pyroot_binding from .highlevel import Correction, CorrectionSet from .version import version as __version__ -__all__ = ("__version__", "CorrectionSet", "Correction") +__all__ = ("__version__", "CorrectionSet", "Correction", "register_pyroot_binding") diff --git a/src/correctionlib/binding.py b/src/correctionlib/binding.py new file mode 100644 index 00000000..8edcf21f --- /dev/null +++ b/src/correctionlib/binding.py @@ -0,0 +1,25 @@ +def register_pyroot_binding() -> None: + import os.path + import sys + + import pkg_resources + from cppyy import gbl # PyROOT without pythonization + + # maybe not the most robust solution? + if sys.platform.startswith("win32"): + lib = pkg_resources.resource_filename( + "correctionlib", os.path.join("lib", "correctionlib.dll") + ) + elif sys.platform.startswith("darwin"): + lib = pkg_resources.resource_filename( + "correctionlib", os.path.join("lib", "libcorrectionlib.dylib") + ) + else: + lib = pkg_resources.resource_filename( + "correctionlib", os.path.join("lib", "libcorrectionlib.so") + ) + gbl.gSystem.Load(lib) + gbl.gInterpreter.AddIncludePath( + pkg_resources.resource_filename("correctionlib", "include") + ) + gbl.gROOT.ProcessLine('#include "correction.h"') diff --git a/src/correctionlib/cli.py b/src/correctionlib/cli.py index 07f2ee93..08ca9736 100644 --- a/src/correctionlib/cli.py +++ b/src/correctionlib/cli.py @@ -6,6 +6,7 @@ from rich.console import Console +import correctionlib.version from correctionlib.highlevel import model_auto, open_auto @@ -117,6 +118,48 @@ def setup_merge(subparsers: argparse._SubParsersAction) -> None: parser.add_argument("files", nargs="+", metavar="FILE") +def config(console: Console, args: argparse.Namespace) -> int: + import pkg_resources + + incdir = pkg_resources.resource_filename("correctionlib", "include") + libdir = pkg_resources.resource_filename("correctionlib", "lib") + out = [] + if args.version: + out.append(correctionlib.version.version) + if args.incdir: + out.append(incdir) + if args.cflags: + out.append(f"-std=c++17 -I{incdir}") + if args.libdir: + out.append(libdir) + if args.ldflags: + out.append(f"-L{libdir} -lcorrectionlib") + if args.rpath: + out.append(f"-Wl,-rpath,{libdir}") + if args.cmake: + out.append( + f"-Dcorrectionlib_DIR={pkg_resources.resource_filename('correctionlib', 'cmake')}" + ) + console.out(" ".join(out), highlight=False) + return 0 + + +def setup_config(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser( + "config", help="Configuration and linking information" + ) + parser.set_defaults(command=config) + parser.add_argument("-v", "--version", action="store_true") + parser.add_argument("--incdir", action="store_true") + parser.add_argument("--cflags", action="store_true") + parser.add_argument("--libdir", action="store_true") + parser.add_argument("--ldflags", action="store_true") + parser.add_argument( + "--rpath", action="store_true", help="Include library path hint in linker" + ) + parser.add_argument("--cmake", action="store_true", help="CMake dependency flags") + + def main() -> int: parser = argparse.ArgumentParser(prog="correction", description=__doc__) parser.add_argument( @@ -125,14 +168,15 @@ def main() -> int: default=100, help="Rich output width", ) - parser.add_argument("--html", type=str, help="Save HTML output to a file") + parser.add_argument("--html", type=str, help="Save terminal output to an HTML file") subparsers = parser.add_subparsers() setup_validate(subparsers) setup_summary(subparsers) setup_merge(subparsers) + setup_config(subparsers) args = parser.parse_args() - console = Console(width=args.width, record=True) + console = Console(width=args.width, record=bool(args.html)) # py3.7: subparsers has required=True option if hasattr(args, "command"): retcode: int = args.command(console, args) diff --git a/src/correctionlib/highlevel.py b/src/correctionlib/highlevel.py index f032b031..290ade55 100644 --- a/src/correctionlib/highlevel.py +++ b/src/correctionlib/highlevel.py @@ -45,6 +45,12 @@ def model_auto(data: Any) -> Any: class Correction: + """High-level correction evaluator object + + This class is typically instantiated by accessing a named correction from + a CorrectionSet object, rather than directly by construction. + """ + def __init__(self, base: correctionlib._core.Correction, context: "CorrectionSet"): self._base = base self._name = base.name @@ -90,6 +96,14 @@ def evaluate( class CorrectionSet(Mapping[str, Correction]): + """High-level correction set evaluator object + + This class can be initialized directly from a model with compatible + schema version, or can be initialized via the ``from_file`` or + ``from_string`` factory methods. Corrections can be accessed + via getitem syntax, e.g. ``cset["some correction"]``. + """ + def __init__(self, model: Any, *, schema_version: Optional[int] = None): if schema_version is None: this_version = correctionlib.version.version_tuple[0] diff --git a/src/correctionlib/schemav1.py b/src/correctionlib/schemav1.py index 41d5e535..1fb2c9bc 100644 --- a/src/correctionlib/schemav1.py +++ b/src/correctionlib/schemav1.py @@ -51,7 +51,7 @@ class MultiBinning(Model): """Bin edges for each input C-ordered array, e.g. content[d1*d2*d3*i0 + d2*d3*i1 + d3*i2 + i3] corresponds - to the element at i0 in dimension 0, i1 in dimension 1, etc. and d0 = len(edges[0]), etc. + to the element at i0 in dimension 0, i1 in dimension 1, etc. and d0 = len(edges[0])-1, etc. """ content: List[Content] diff --git a/src/correctionlib/schemav2.py b/src/correctionlib/schemav2.py index 921d94eb..8cbe35c9 100644 --- a/src/correctionlib/schemav2.py +++ b/src/correctionlib/schemav2.py @@ -7,6 +7,8 @@ from rich.panel import Panel from rich.tree import Tree +import correctionlib.highlevel + try: from typing import Literal # type: ignore except ImportError: @@ -71,6 +73,9 @@ def summarize( self, nodecount: Dict[str, int], inputstats: Dict[str, _SummaryInfo] ) -> None: nodecount["Formula"] += 1 + for input in self.variables: + inputstats[input].min = float("-inf") + inputstats[input].max = float("inf") class FormulaRef(Model): @@ -173,7 +178,7 @@ class MultiBinning(Model): content: List[Content] = Field( description="""Bin contents as a flattened array This is a C-ordered array, i.e. content[d1*d2*d3*i0 + d2*d3*i1 + d3*i2 + i3] corresponds - to the element at i0 in dimension 0, i1 in dimension 1, etc. and d0 = len(edges[0]), etc. + to the element at i0 in dimension 0, i1 in dimension 1, etc. and d0 = len(edges[0])-1, etc. """ ) flow: Union[Content, Literal["clamp", "error"]] = Field( @@ -307,6 +312,9 @@ def summary(self) -> Tuple[Dict[str, int], Dict[str, _SummaryInfo]]: inputstats = {var.name: _SummaryInfo() for var in self.inputs} if not isinstance(self.data, float): self.data.summarize(nodecount, inputstats) + if self.generic_formulas: + for formula in self.generic_formulas: + formula.summarize(nodecount, inputstats) return nodecount, inputstats def __rich_console__( @@ -322,7 +330,10 @@ def __rich_console__( def fmt_input(var: Variable, stats: _SummaryInfo) -> str: out = var.__rich__() if var.type == "real": - out += f"\nRange: [{stats.min}, {stats.max})" + if stats.min == float("inf") and stats.max == float("-inf"): + out += "\nRange: [bold red]unused[/bold red]" + else: + out += f"\nRange: [{stats.min}, {stats.max})" if stats.overflow: out += ", overflow ok" if stats.transform: @@ -347,6 +358,11 @@ def fmt_input(var: Variable, stats: _SummaryInfo) -> str: expand=False, ) + def to_evaluator(self) -> correctionlib.highlevel.Correction: + # TODO: consider refactoring highlevel.Correction to be independent + cset = CorrectionSet(schema_version=VERSION, corrections=[self]) + return correctionlib.highlevel.CorrectionSet(cset)[self.name] + class CorrectionSet(Model): schema_version: Literal[VERSION] = Field(description="The overall schema version") @@ -362,6 +378,9 @@ def __rich_console__( tree.add(corr) yield tree + def to_evaluator(self) -> correctionlib.highlevel.CorrectionSet: + return correctionlib.highlevel.CorrectionSet(self) + if __name__ == "__main__": import os From c1b70c580d9872faf0444c447588481124146ef2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 3 May 2021 14:27:50 -0500 Subject: [PATCH 13/13] Bump joerick/cibuildwheel from v1.10.0 to v1.11.0 (#77) Bumps [joerick/cibuildwheel](https://github.com/joerick/cibuildwheel) from v1.10.0 to v1.11.0. - [Release notes](https://github.com/joerick/cibuildwheel/releases) - [Changelog](https://github.com/joerick/cibuildwheel/blob/master/docs/changelog.md) - [Commits](https://github.com/joerick/cibuildwheel/compare/v1.10.0...aa12480ff0e5381eca2258a6957aea6af5c46172) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Nicholas Smith --- .github/workflows/wheels.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index d98cd0f3..2cee0067 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -35,7 +35,7 @@ jobs: with: submodules: recursive - - uses: joerick/cibuildwheel@v1.10.0 + - uses: joerick/cibuildwheel@v1.11.0 env: CIBW_SKIP: cp27* CIBW_TEST_EXTRAS: test @@ -61,7 +61,7 @@ jobs: with: arch: x86 - - uses: joerick/cibuildwheel@v1.10.0 + - uses: joerick/cibuildwheel@v1.11.0 env: CIBW_BUILD: cp3*-win32 CIBW_TEST_EXTRAS: test @@ -71,7 +71,7 @@ jobs: with: arch: x64 - - uses: joerick/cibuildwheel@v1.10.0 + - uses: joerick/cibuildwheel@v1.11.0 env: CIBW_BUILD: cp3*-win_amd64 CIBW_TEST_EXTRAS: test