From 5bfd1a33f980fc3c12cec2105d75c0cc1434b5c5 Mon Sep 17 00:00:00 2001
From: Nicholas Smith
Date: Thu, 1 Feb 2024 16:57:27 -0600
Subject: [PATCH] Migrate to pydantic2 (#220)

* Migrate to pydantic2

* Add targeted test for .to_evaluator()

* Migrate fully to pydantic 2

Resolve all pydantic deprecation warnings

Add some more model validation tests (and fix a bug in C++ validation!)
---
 .pre-commit-config.yaml          |  2 +-
 data/conversion.py               | 26 ++++-----
 setup.cfg                        |  2 +-
 src/correction.cc                |  6 +--
 src/correctionlib/JSONEncoder.py |  5 +-
 src/correctionlib/cli.py         |  4 +-
 src/correctionlib/convert.py     | 10 ++--
 src/correctionlib/highlevel.py   |  6 +--
 src/correctionlib/schemav1.py    |  9 ++--
 src/correctionlib/schemav2.py    | 87 ++++++++++++++++++++------------
 tests/test_compound.py           |  4 +-
 tests/test_core.py               | 10 ++--
 tests/test_core_valid.py         | 16 +++++-
 tests/test_core_vect.py          |  2 +-
 tests/test_formula_ast.py        |  3 +-
 tests/test_hashprng.py           |  2 +-
 tests/test_highlevel.py          | 25 +++++++++
 tests/test_unique.py             |  6 +--
 18 files changed, 141 insertions(+), 84 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 08e9f534..d9ff6837 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -41,4 +41,4 @@ repos:
     hooks:
       - id: mypy
         files: src
-        additional_dependencies: [uhi, pydantic<2, numpy, types-setuptools]
+        additional_dependencies: [uhi, pydantic, numpy, types-setuptools]
diff --git a/data/conversion.py b/data/conversion.py
index cfd24379..f9b0d471 100755
--- a/data/conversion.py
+++ b/data/conversion.py
@@ -33,7 +33,7 @@ def build_formula(sf):
     value = sf.iloc[0]["formula"]
     if "x" in value:
-        return Formula.parse_obj(
+        return Formula.model_validate(
             {
                 "nodetype": "formula",
                 "expression": value,
@@ -49,7 +49,7 @@ def build_discrbinning(sf):
     edges = sorted(set(sf["discrMin"]) | set(sf["discrMax"]))
 
-    return Binning.parse_obj(
+    return Binning.model_validate(
         {
             "nodetype": "binning",
             "input": "discriminant",
@@ -65,7 +65,7 @@ def build_ptbinning(sf):
     edges = sorted(set(sf["ptMin"]) | set(sf["ptMax"]))
 
-    return Binning.parse_obj(
+    return Binning.model_validate(
         {
             "nodetype": "binning",
             "input": "pt",
@@ -81,7 +81,7 @@ def build_etabinning(sf):
     edges = sorted(set(sf["etaMin"]) | set(sf["etaMax"]))
 
-    return Binning.parse_obj(
+    return Binning.model_validate(
         {
             "nodetype": "binning",
             "input": "abseta",
@@ -97,7 +97,7 @@ def build_flavor(sf):
     keys = sorted(sf["jetFlavor"].unique())
 
-    return Category.parse_obj(
+    return Category.model_validate(
         {
             "nodetype": "category",
             "input": "flavor",
@@ -111,7 +111,7 @@ def build_systs(sf):
     keys = list(sf["sysType"].unique())
 
-    return Category.parse_obj(
+    return Category.model_validate(
         {
             "nodetype": "category",
             "input": "systematic",
@@ -123,7 +123,7 @@ def build_systs(sf):
     )
 
 
-corr2 = Correction.parse_obj(
+corr2 = Correction.model_validate(
     {
         "version": 1,
        "name": "DeepCSV_2016LegacySF",
@@ -149,11 +149,11 @@ def build_systs(sf):
 )
 
 
 sf = requests.get(f"{examples}/EIDISO_WH_out.histo.json").json()
 
 
 def build_syst(sf):
-    return Category.parse_obj(
+    return Category.model_validate(
         {
             "nodetype": "category",
             "input": "systematic",
@@ -187,7 +187,7 @@ def build_pts(sf):
         edges.append(hi)
         content.append(build_syst(data))
 
-    return Binning.parse_obj(
+    return Binning.model_validate(
         {
             "nodetype": "binning",
             "input": "pt",
@@ -212,7 +212,7 @@ def build_etas(sf):
         if not found:
             raise ValueError("eta edges not in binning?")
 
-    return Binning.parse_obj(
+    return Binning.model_validate(
         {
             "nodetype": "binning",
             "input": "eta",
@@ -223,7 +223,7 @@ def build_etas(sf):
     )
 
 
-corr3 = Correction.parse_obj(
+corr3 = Correction.model_validate(
     {
         "version": 1,
         "name": "EIDISO_WH_out",
@@ -239,7 +239,7 @@ def build_etas(sf):
 )
 
 
-cset = CorrectionSet.parse_obj(
+cset = CorrectionSet.model_validate(
     {
         "schema_version": VERSION,
         "corrections": [
@@ -251,4 +251,4 @@ def build_etas(sf):
 )
 
 with gzip.open("data/examples.json.gz", "wt") as fout:
-    fout.write(cset.json(exclude_unset=True))
+    fout.write(cset.model_dump_json(exclude_unset=True))
diff --git a/setup.cfg b/setup.cfg
index 129a53a1..e374ca8b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -29,7 +29,7 @@ install_requires =
     typing; python_version<"3.5"
     typing-extensions;python_version<"3.8"
     dataclasses;python_version<"3.7"
-    pydantic <2,>=1.7.3
+    pydantic >=2
     rich
     packaging
diff --git a/src/correction.cc b/src/correction.cc
index 03894cb2..4757f5e1 100644
--- a/src/correction.cc
+++ b/src/correction.cc
@@ -103,7 +103,7 @@ std::string_view JSONObject::getRequired<std::string_view>(const char * key) const {
 
 namespace {
 Content resolve_content(const rapidjson::Value& json, const Correction& context) {
-  if ( json.IsDouble() ) { return json.GetDouble(); }
+  if ( json.IsNumber() ) { return json.GetDouble(); }
   else if ( json.IsObject() && json.HasMember("nodetype") ) {
     auto obj = JSONObject(json.GetObject());
     auto type = obj.getRequired<std::string_view>("nodetype");
@@ -569,13 +569,13 @@ Category::Category(const JSONObject& json, const Correction& context)
     }
     if ( kv_pair["key"].IsString() ) {
       if ( variable.type() != Variable::VarType::string ) {
-        throw std::runtime_error("Category got a key not of type string, but its input is string type");
+        throw std::runtime_error("Category got a key of type string, but its input is type " + variable.typeStr());
       }
       std::get<StrMap>(map_).try_emplace(kv_pair["key"].GetString(), resolve_content(kv_pair["value"], context));
     }
     else if ( kv_pair["key"].IsInt() ) {
       if ( variable.type() != Variable::VarType::integer ) {
-        throw std::runtime_error("Category got a key not of type int, but its input is int type");
+        throw std::runtime_error("Category got a key of type int, but its input is type " + variable.typeStr());
       }
       std::get<IntMap>(map_).try_emplace(kv_pair["key"].GetInt(), resolve_content(kv_pair["value"], context));
     }
diff --git a/src/correctionlib/JSONEncoder.py b/src/correctionlib/JSONEncoder.py
index bc925c1a..5ff0d272 100755
--- a/src/correctionlib/JSONEncoder.py
+++ b/src/correctionlib/JSONEncoder.py
@@ -29,9 +29,8 @@ def write(data: Any, fname: str, **kwargs: Any) -> None:
 
 
 def dumps(data: Any, sort_keys: bool = False, **kwargs: Any) -> str:
     """Help function to quickly dump dictionary formatted by JSONEncoder."""
     if isinstance(data, pydantic.BaseModel):  # for pydantic
-        return data.json(cls=JSONEncoder, exclude_unset=True, **kwargs)
-    else:  # for standard data structures
-        return json.dumps(data, cls=JSONEncoder, sort_keys=sort_keys, **kwargs)
+        data = data.model_dump(mode="json", exclude_unset=True)
+    return json.dumps(data, cls=JSONEncoder, sort_keys=sort_keys, **kwargs)
diff --git a/src/correctionlib/cli.py b/src/correctionlib/cli.py
index 85742e88..16258673 100644
--- a/src/correctionlib/cli.py
+++ b/src/correctionlib/cli.py
@@ -101,9 +101,9 @@ def merge(console: Console, args: argparse.Namespace) -> int:
             cset.compound_corrections.append(corr2)
     cset.description = "Merged from " + " ".join(args.files)
".join(args.files) if args.format == "compact": - sys.stdout.write(cset.json()) + sys.stdout.write(cset.model_dump_json()) elif args.format == "indented": - sys.stdout.write(cset.json(indent=4) + "\n") + sys.stdout.write(cset.model_dump_json(indent=4) + "\n") elif args.format == "pretty": from correctionlib.JSONEncoder import dumps diff --git a/src/correctionlib/convert.py b/src/correctionlib/convert.py index 489c6b60..ad18194e 100644 --- a/src/correctionlib/convert.py +++ b/src/correctionlib/convert.py @@ -82,7 +82,7 @@ def read_axis(axis: "PlottableAxis", pos: int) -> Variable: axname = getattr( axis, "name", f"axis{pos}" if axis_names is None else axis_names[pos] ) - return Variable.parse_obj( + return Variable.model_validate( { "type": axtype, "name": axname, @@ -120,7 +120,7 @@ def build_data( ) -> Content: vartype = variables[0].type if vartype in {"string", "int"}: - return Category.parse_obj( + return Category.model_validate( { "nodetype": "category", "input": variables[0].name, @@ -142,7 +142,7 @@ def build_data( break i += 1 if i > 1: - return MultiBinning.parse_obj( + return MultiBinning.model_validate( { "nodetype": "multibinning", "edges": [edges(ax) for ax in axes[:i]], @@ -156,7 +156,7 @@ def build_data( "flow": flow, } ) - return Binning.parse_obj( + return Binning.model_validate( { "nodetype": "binning", "input": variables[0].name, @@ -171,7 +171,7 @@ def build_data( } ) - return Correction.parse_obj( + return Correction.model_validate( { "version": 0, "name": getattr(hist, "name", "unknown"), diff --git a/src/correctionlib/highlevel.py b/src/correctionlib/highlevel.py index 63ec25e0..52163551 100644 --- a/src/correctionlib/highlevel.py +++ b/src/correctionlib/highlevel.py @@ -40,11 +40,11 @@ def model_auto(data: str) -> Any: if version == 1: import correctionlib.schemav1 - return correctionlib.schemav1.CorrectionSet.parse_obj(data) + return correctionlib.schemav1.CorrectionSet.model_validate(data) elif version == 2: import correctionlib.schemav2 - return correctionlib.schemav2.CorrectionSet.parse_obj(data) + return correctionlib.schemav2.CorrectionSet.model_validate(data) raise ValueError(f"Unknown CorrectionSet schema version ({version})") @@ -292,7 +292,7 @@ def __init__(self, data: Any): if isinstance(data, str): self._data = data else: - self._data = data.json(exclude_unset=True) + self._data = data.model_dump_json(exclude_unset=True) self._base = correctionlib._core.CorrectionSet.from_string(self._data) @classmethod diff --git a/src/correctionlib/schemav1.py b/src/correctionlib/schemav1.py index 250aac42..47d0b6f6 100644 --- a/src/correctionlib/schemav1.py +++ b/src/correctionlib/schemav1.py @@ -1,7 +1,7 @@ import sys from typing import List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict if sys.version_info >= (3, 8): from typing import Literal @@ -13,15 +13,14 @@ class Model(BaseModel): - class Config: - extra = "forbid" + model_config = ConfigDict(extra="forbid") class Variable(Model): name: str type: Literal["string", "int", "real"] "Implicitly 64 bit integer and double-precision floating point?" - description: Optional[str] + description: Optional[str] = None # TODO: clamping behavior for out of range? 
@@ -72,7 +71,7 @@ class Category(Model): class Correction(Model): name: str "A useful name" - description: Optional[str] + description: Optional[str] = None "Detailed description of the correction" version: int "Version" diff --git a/src/correctionlib/schemav2.py b/src/correctionlib/schemav2.py index c2d716df..060cc5d0 100644 --- a/src/correctionlib/schemav2.py +++ b/src/correctionlib/schemav2.py @@ -1,8 +1,16 @@ import sys from collections import defaultdict -from typing import Any, Dict, List, Optional, Set, Tuple, Union - -from pydantic import BaseModel, Field, StrictInt, StrictStr, validator +from typing import Dict, List, Optional, Set, Tuple, Union + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + StrictInt, + StrictStr, + ValidationInfo, + field_validator, +) from rich.columns import Columns from rich.console import Console, ConsoleOptions, RenderResult from rich.panel import Panel @@ -20,8 +28,7 @@ class Model(BaseModel): - class Config: - extra = "forbid" + model_config = ConfigDict(extra="forbid") class _SummaryInfo: @@ -140,7 +147,7 @@ class HashPRNG(Model): nodetype: Literal["hashprng"] inputs: List[str] = Field( description="The names of the input variables to use as entropy sources", - min_items=1, + min_length=1, ) distribution: Literal["stdflat", "stdnormal", "normal"] = Field( description="The output distribution to draw from" @@ -159,15 +166,17 @@ class UniformBinning(Model): low: float = Field(description="Lower edge of first bin") high: float = Field(description="Higher edge of last bin") - @validator("n") + @field_validator("n") + @classmethod def validate_n(cls, n: int) -> int: if n <= 0: # downstream C++ logic assumes there is at least one bin raise ValueError(f"Number of bins must be greater than 0, got {n}") return n - @validator("high") - def validate_edges(cls, high: float, values: Any) -> float: - low = values["low"] + @field_validator("high") + @classmethod + def validate_edges(cls, high: float, info: ValidationInfo) -> float: + low = info.data["low"] if low >= high: raise ValueError( f"Higher bin edge must be larger than lower, got {[low, high]}" @@ -190,7 +199,8 @@ class Binning(Model): description="Overflow behavior for out-of-bounds values" ) - @validator("edges") + @field_validator("edges") + @classmethod def validate_edges( cls, edges: Union[List[float], UniformBinning] ) -> Union[List[float], UniformBinning]: @@ -203,13 +213,16 @@ def validate_edges( return edges - @validator("content") - def validate_content(cls, content: List[Content], values: Any) -> List[Content]: - assert "edges" in values - if isinstance(values["edges"], list): - nbins = len(values["edges"]) - 1 + @field_validator("content") + @classmethod + def validate_content( + cls, content: List[Content], info: ValidationInfo + ) -> List[Content]: + assert "edges" in info.data + if isinstance(info.data["edges"], list): + nbins = len(info.data["edges"]) - 1 else: - nbins = values["edges"].n + nbins = info.data["edges"].n if nbins != len(content): raise ValueError( f"Binning content length ({len(content)}) is not one less than edges ({nbins + 1})" @@ -238,7 +251,7 @@ class MultiBinning(Model): nodetype: Literal["multibinning"] inputs: List[str] = Field( description="The names of the correction input variables this binning applies to", - min_items=1, + min_length=1, ) edges: List[Union[List[float], UniformBinning]] = Field( description="Bin edges for each input" @@ -253,9 +266,10 @@ class MultiBinning(Model): description="Overflow behavior for out-of-bounds values" ) - 
@validator("edges") + @field_validator("edges") + @classmethod def validate_edges( - cls, edges: List[Union[List[float], UniformBinning]], values: Any + cls, edges: List[Union[List[float], UniformBinning]] ) -> List[Union[List[float], UniformBinning]]: for i, dim in enumerate(edges): if isinstance(dim, list): @@ -266,11 +280,14 @@ def validate_edges( ) return edges - @validator("content") - def validate_content(cls, content: List[Content], values: Any) -> List[Content]: - assert "edges" in values + @field_validator("content") + @classmethod + def validate_content( + cls, content: List[Content], info: ValidationInfo + ) -> List[Content]: + assert "edges" in info.data nbins = 1 - for dim in values["edges"]: + for dim in info.data["edges"]: if isinstance(dim, list): nbins *= len(dim) - 1 else: @@ -318,7 +335,8 @@ class Category(Model): content: List[CategoryItem] default: Optional[Content] = None - @validator("content") + @field_validator("content") + @classmethod def validate_content(cls, content: List[CategoryItem]) -> List[CategoryItem]: if len(content): keytype = type(content[0].key) @@ -345,11 +363,11 @@ def summarize( self.default.summarize(nodecount, inputstats) -Transform.update_forward_refs() -Binning.update_forward_refs() -MultiBinning.update_forward_refs() -CategoryItem.update_forward_refs() -Category.update_forward_refs() +Transform.model_rebuild() +Binning.model_rebuild() +MultiBinning.model_rebuild() +CategoryItem.model_rebuild() +Category.model_rebuild() class Correction(Model): @@ -377,7 +395,8 @@ class Correction(Model): ) data: Content = Field(description="The root content node") - @validator("output") + @field_validator("output") + @classmethod def validate_output(cls, output: Variable) -> Variable: if output.type != "real": raise ValueError( @@ -512,7 +531,8 @@ class CorrectionSet(Model): corrections: List[Correction] compound_corrections: Optional[List[CompoundCorrection]] = None - @validator("corrections") + @field_validator("corrections") + @classmethod def validate_corrections(cls, items: List[Correction]) -> List[Correction]: seen = set() dupe = set() @@ -526,7 +546,8 @@ def validate_corrections(cls, items: List[Correction]) -> List[Correction]: ) return items - @validator("compound_corrections") + @field_validator("compound_corrections") + @classmethod def validate_compound( cls, items: Optional[List[CompoundCorrection]] ) -> Optional[List[CompoundCorrection]]: diff --git a/tests/test_compound.py b/tests/test_compound.py index 0c3a764c..f4328092 100644 --- a/tests/test_compound.py +++ b/tests/test_compound.py @@ -6,7 +6,7 @@ def test_compound(): - cset = correctionlib.schemav2.CorrectionSet.parse_obj( + cset = correctionlib.schemav2.CorrectionSet.model_validate( { "schema_version": 2, "corrections": [ @@ -63,7 +63,7 @@ def test_compound(): ], } ) - cset = correctionlib.CorrectionSet.from_string(cset.json()) + cset = correctionlib.CorrectionSet.from_string(cset.model_dump_json()) corr = cset.compound["l1l2"] assert corr.evaluate(10.0, 1.2) == 1 + 0.1 * math.log10(10 * 1.1) + 0.1 * 1.2 assert corr.evaluate(10.0, 0.0) == 1 + 0.1 * math.log10(10 * 1.1) diff --git a/tests/test_core.py b/tests/test_core.py index 595eb7a4..defede27 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -13,7 +13,7 @@ def wrap(*corrs): schema_version=schema.VERSION, corrections=list(corrs), ) - return core.CorrectionSet.from_string(cset.json()) + return core.CorrectionSet.from_string(cset.model_dump_json()) def test_evaluator(): @@ -63,14 +63,14 @@ def test_evaluator(): 
schema.Variable(name="syst", type="string"), ], output=schema.Variable(name="a scale", type="real"), - data=schema.Binning.parse_obj( + data=schema.Binning.model_validate( { "nodetype": "binning", "input": "pt", "edges": [0, 20, 40, float("inf")], "flow": "error", "content": [ - schema.Category.parse_obj( + schema.Category.model_validate( { "nodetype": "category", "input": "syst", @@ -80,7 +80,7 @@ def test_evaluator(): ], } ), - schema.Category.parse_obj( + schema.Category.model_validate( { "nodetype": "category", "input": "syst", @@ -164,7 +164,7 @@ def evaluate(expr, variables, parameters): } ], } - schema.CorrectionSet.parse_obj(cset) + schema.CorrectionSet.model_validate(cset) corr = core.CorrectionSet.from_string(json.dumps(cset))["test"] return corr.evaluate(*variables) diff --git a/tests/test_core_valid.py b/tests/test_core_valid.py index 1b1cd765..3a14bf34 100644 --- a/tests/test_core_valid.py +++ b/tests/test_core_valid.py @@ -1,5 +1,7 @@ import pytest +from pydantic import ValidationError +from correctionlib import schemav2 as schema from correctionlib._core import CorrectionSet @@ -25,18 +27,28 @@ def test_evaluator_validation(): '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "string"}}]}', '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "string"}, "inputs": []}]}', '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "string"}, "inputs": [], "data": 1}]}', - '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "real"}, "inputs": [], "data": 1}]}', '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "real"}, "inputs": [], "data": {"nodetype": 3}}]}', '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "real"}, "inputs": [], "data": {"nodetype": "blah"}}]}', '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "real"}, "inputs": [], "data": {"nodetype": "category", "input": "blah"}}]}', '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "real"}, "inputs": [{"name":"blah", "type": "int"}], "data": {"nodetype": "category", "input": "blah", "content": [3]}}]}', '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "real"}, "inputs": [{"name":"blah", "type": "int"}], "data": {"nodetype": "category", "input": "blah", "content": [{"key": null, "value": 3}]}}]}', '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "real"}, "inputs": [{"name":"blah", "type": "int"}], "data": {"nodetype": "category", "input": "blah", "content": [{"key": 1.2, "val": 3}]}}]}', - '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "real"}, "inputs": [{"name":"blah", "type": "int"}], "data": {"nodetype": "category", "input": "blah", "content": [{"key": "a", "value": 3.0}]}}]}', '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "real"}, "inputs": [{"name":"blah", "type": "int"}], "data": {"nodetype": "category", "input": "blah", "content": [{"key": 1, "value": 3.0}], "default": "f"}}]}', '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "real"}, "inputs": [{"name":"blah", "type": "int"}], "data": 
{"nodetype": "formula"}}]}', ] for json in bad_json: + with pytest.raises(ValidationError): + schema.CorrectionSet.model_validate_json(json) + pytest.fail(f"{json} did not fail validation") with pytest.raises(RuntimeError): CorrectionSet.from_string(json) + + # TODO: this is not detected by the pydantic model yet + hard_json = '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "real"}, "inputs": [{"name":"blah", "type": "int"}], "data": {"nodetype": "category", "input": "blah", "content": [{"key": "a", "value": 3.0}]}}]}' + with pytest.raises(RuntimeError): + CorrectionSet.from_string(hard_json) + + # Was previously invalid in the C++ parser, but now considered valid as JSON Schema does not reject numeric types without a decimal + good_json = '{"schema_version":2, "corrections": [{"name": "hi", "version": 2, "output": {"name": "hi","type": "real"}, "inputs": [], "data": 1}]}' + CorrectionSet.from_string(good_json) diff --git a/tests/test_core_vect.py b/tests/test_core_vect.py index 4e704851..e6eec42f 100644 --- a/tests/test_core_vect.py +++ b/tests/test_core_vect.py @@ -10,7 +10,7 @@ def wrap(*corrs): schema_version=schema.VERSION, corrections=list(corrs), ) - return core.CorrectionSet.from_string(cset.json()) + return core.CorrectionSet.from_string(cset.model_dump_json()) def test_core_vectorized(): diff --git a/tests/test_formula_ast.py b/tests/test_formula_ast.py index c95c8173..6db99392 100644 --- a/tests/test_formula_ast.py +++ b/tests/test_formula_ast.py @@ -19,7 +19,8 @@ def test_access_formula_ast(): ) formula = core.Formula.from_string( - c.data.json(), [core.Variable.from_string(v.json()) for v in c.inputs] + c.data.model_dump_json(), + [core.Variable.from_string(v.model_dump_json()) for v in c.inputs], ) ast = formula.ast diff --git a/tests/test_hashprng.py b/tests/test_hashprng.py index ece79d93..54cd446d 100644 --- a/tests/test_hashprng.py +++ b/tests/test_hashprng.py @@ -3,7 +3,7 @@ def test_hashprng(): - cset = correctionlib.schemav2.CorrectionSet.parse_obj( + cset = correctionlib.schemav2.CorrectionSet.model_validate( { "schema_version": 2, "corrections": [ diff --git a/tests/test_highlevel.py b/tests/test_highlevel.py index e86e61a0..79d617f3 100644 --- a/tests/test_highlevel.py +++ b/tests/test_highlevel.py @@ -106,3 +106,28 @@ def test_highlevel_dask(cset): awkward.flatten(evaluate).compute(), numpy.full(6, 1.234), ) + + +def test_model_to_evaluator(): + m = model.CorrectionSet( + schema_version=model.VERSION, + corrections=[ + model.Correction( + name="test corr", + version=2, + inputs=[ + model.Variable(name="a", type="real"), + model.Variable(name="b", type="real"), + ], + output=model.Variable(name="a scale", type="real"), + data=1.234, + ) + ], + ) + cset = m.to_evaluator() + assert set(cset) == {"test corr"} + + sf = m.corrections[0].to_evaluator() + assert sf.version == 2 + assert sf.description == "" + assert sf.evaluate(1.0, 1.0) == 1.234 diff --git a/tests/test_unique.py b/tests/test_unique.py index 7f6a4e82..6b867909 100644 --- a/tests/test_unique.py +++ b/tests/test_unique.py @@ -25,7 +25,7 @@ def makec(name): "stack": [], } - correctionlib.schemav2.CorrectionSet.parse_obj( + correctionlib.schemav2.CorrectionSet.model_validate( { "schema_version": 2, "corrections": [make("thing1"), make("thing2")], @@ -34,12 +34,12 @@ def makec(name): ) with pytest.raises(ValueError): - correctionlib.schemav2.CorrectionSet.parse_obj( + correctionlib.schemav2.CorrectionSet.model_validate( {"schema_version": 2, "corrections": 
[make("thing1"), make("thing1")]} ) with pytest.raises(ValueError): - correctionlib.schemav2.CorrectionSet.parse_obj( + correctionlib.schemav2.CorrectionSet.model_validate( { "schema_version": 2, "corrections": [make("thing1"), make("thing1")],