Migrate to pydantic2 (#220)
* Migrate to pydantic2

* Add targeted test for .to_evaluator()

* Migrate fully to pydantic 2

Resolve all pydantic deprecation warnings
Add some more model validation tests (and fix a bug in C++ validation!)
nsmith- authored Feb 1, 2024
1 parent 0d82d12 commit 5bfd1a3
Showing 18 changed files with 142 additions and 85 deletions.
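Almost all of the Python changes below are the mechanical rename from the pydantic v1 API to its v2 equivalents. As a quick reference (a generic pydantic illustration, not code from this repository), the two renames that recur throughout the diff are:

# pydantic v1 -> v2 renames exercised by this commit (illustrative toy model)
from pydantic import BaseModel

class Item(BaseModel):
    name: str

payload = {"name": "pt"}
item = Item.model_validate(payload)   # v1: Item.parse_obj(payload)
text = item.model_dump_json()         # v1: item.json()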
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -41,4 +41,4 @@ repos:
hooks:
- id: mypy
files: src
- additional_dependencies: [uhi, pydantic<2, numpy, types-setuptools]
+ additional_dependencies: [uhi, pydantic, numpy, types-setuptools]
28 changes: 14 additions & 14 deletions data/conversion.py
@@ -33,7 +33,7 @@ def build_formula(sf):

value = sf.iloc[0]["formula"]
if "x" in value:
- return Formula.parse_obj(
+ return Formula.model_validate(
{
"nodetype": "formula",
"expression": value,
@@ -49,7 +49,7 @@ def build_formula(sf):

def build_discrbinning(sf):
edges = sorted(set(sf["discrMin"]) | set(sf["discrMax"]))
- return Binning.parse_obj(
+ return Binning.model_validate(
{
"nodetype": "binning",
"input": "discriminant",
@@ -65,7 +65,7 @@ def build_discrbinning(sf):

def build_ptbinning(sf):
edges = sorted(set(sf["ptMin"]) | set(sf["ptMax"]))
- return Binning.parse_obj(
+ return Binning.model_validate(
{
"nodetype": "binning",
"input": "pt",
@@ -81,7 +81,7 @@ def build_ptbinning(sf):

def build_etabinning(sf):
edges = sorted(set(sf["etaMin"]) | set(sf["etaMax"]))
- return Binning.parse_obj(
+ return Binning.model_validate(
{
"nodetype": "binning",
"input": "abseta",
@@ -97,7 +97,7 @@ def build_etabinning(sf):

def build_flavor(sf):
keys = sorted(sf["jetFlavor"].unique())
- return Category.parse_obj(
+ return Category.model_validate(
{
"nodetype": "category",
"input": "flavor",
@@ -111,7 +111,7 @@ def build_flavor(sf):

def build_systs(sf):
keys = list(sf["sysType"].unique())
- return Category.parse_obj(
+ return Category.model_validate(
{
"nodetype": "category",
"input": "systematic",
@@ -123,7 +123,7 @@ def build_systs(sf):
)


- corr2 = Correction.parse_obj(
+ corr2 = Correction.model_validate(
{
"version": 1,
"name": "DeepCSV_2016LegacySF",
@@ -149,11 +149,11 @@ def build_systs(sf):
)


- sf = requests.get(f"{examples}/EIDISO_WH_out.histo.json").json()
+ sf = requests.get(f"{examples}/EIDISO_WH_out.histo.json").model_dump_json()


def build_syst(sf):
- return Category.parse_obj(
+ return Category.model_validate(
{
"nodetype": "category",
"input": "systematic",
@@ -187,7 +187,7 @@ def build_pts(sf):
edges.append(hi)
content.append(build_syst(data))

- return Binning.parse_obj(
+ return Binning.model_validate(
{
"nodetype": "binning",
"input": "pt",
@@ -212,7 +212,7 @@ def build_etas(sf):
if not found:
raise ValueError("eta edges not in binning?")

- return Binning.parse_obj(
+ return Binning.model_validate(
{
"nodetype": "binning",
"input": "eta",
@@ -223,7 +223,7 @@ def build_etas(sf):
)


- corr3 = Correction.parse_obj(
+ corr3 = Correction.model_validate(
{
"version": 1,
"name": "EIDISO_WH_out",
@@ -239,7 +239,7 @@ def build_etas(sf):
)


- cset = CorrectionSet.parse_obj(
+ cset = CorrectionSet.model_validate(
{
"schema_version": VERSION,
"corrections": [
@@ -251,4 +251,4 @@ def build_etas(sf):
)

with gzip.open("data/examples.json.gz", "wt") as fout:
- fout.write(cset.json(exclude_unset=True))
+ fout.write(cset.model_dump_json(exclude_unset=True))
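The exclude_unset=True flag carries over unchanged to model_dump_json: fields that were never explicitly set are dropped from the serialized output, which keeps the gzipped example file minimal. A toy pydantic illustration (not the correctionlib schema):

from typing import Optional
from pydantic import BaseModel

class Node(BaseModel):
    nodetype: str
    description: Optional[str] = None

node = Node(nodetype="binning")
print(node.model_dump_json())                    # {"nodetype":"binning","description":null}
print(node.model_dump_json(exclude_unset=True))  # {"nodetype":"binning"}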
2 changes: 1 addition & 1 deletion setup.cfg
@@ -29,7 +29,7 @@ install_requires =
typing; python_version<"3.5"
typing-extensions;python_version<"3.8"
dataclasses;python_version<"3.7"
- pydantic <2,>=1.7.3
+ pydantic >=2
rich
packaging

6 changes: 3 additions & 3 deletions src/correction.cc
@@ -103,7 +103,7 @@ std::string_view JSONObject::getRequired<std::string_view>(const char * key) con

namespace {
Content resolve_content(const rapidjson::Value& json, const Correction& context) {
- if ( json.IsDouble() ) { return json.GetDouble(); }
+ if ( json.IsNumber() ) { return json.GetDouble(); }
else if ( json.IsObject() && json.HasMember("nodetype") ) {
auto obj = JSONObject(json.GetObject());
auto type = obj.getRequired<std::string_view>("nodetype");
@@ -569,13 +569,13 @@ Category::Category(const JSONObject& json, const Correction& context)
}
if ( kv_pair["key"].IsString() ) {
if ( variable.type() != Variable::VarType::string ) {
- throw std::runtime_error("Category got a key not of type string, but its input is string type");
+ throw std::runtime_error("Category got a key of type string, but its input is type " + variable.typeStr());
}
std::get<StrMap>(map_).try_emplace(kv_pair["key"].GetString(), resolve_content(kv_pair["value"], context));
}
else if ( kv_pair["key"].IsInt() ) {
if ( variable.type() != Variable::VarType::integer ) {
- throw std::runtime_error("Category got a key not of type int, but its input is int type");
+ throw std::runtime_error("Category got a key of type int, but its input is type " + variable.typeStr());
}
std::get<IntMap>(map_).try_emplace(kv_pair["key"].GetInt(), resolve_content(kv_pair["value"], context));
}
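The resolve_content change above appears to be the C++ validation bug fix called out in the commit message: a constant content value written without a decimal point (1 instead of 1.0) parses as a JSON integer, for which rapidjson's IsDouble() returns false even though GetDouble() can still convert it, so IsNumber() is the safer guard. A small Python illustration of how such tokens arise during serialization (generic json behaviour, not this repository's code):

import json

print(json.dumps({"value": 1}))    # {"value": 1}   -- a JSON integer token
print(json.dumps({"value": 1.0}))  # {"value": 1.0} -- a JSON double token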
5 changes: 2 additions & 3 deletions src/correctionlib/JSONEncoder.py
@@ -29,9 +29,8 @@ def write(data: Any, fname: str, **kwargs: Any) -> None:
def dumps(data: Any, sort_keys: bool = False, **kwargs: Any) -> str:
"""Help function to quickly dump dictionary formatted by JSONEncoder."""
if isinstance(data, pydantic.BaseModel): # for pydantic
- return data.json(cls=JSONEncoder, exclude_unset=True, **kwargs)
- else: # for standard data structures
- return json.dumps(data, cls=JSONEncoder, sort_keys=sort_keys, **kwargs)
+ data = data.model_dump(mode="json", exclude_unset=True)
+ return json.dumps(data, cls=JSONEncoder, sort_keys=sort_keys, **kwargs)


class JSONEncoder(json.JSONEncoder):
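With the updated dumps(), a pydantic model is first converted to plain Python data via model_dump(mode="json") and then formatted by the same json.dumps(..., cls=JSONEncoder) path used for ordinary dictionaries. A short usage sketch (the Variable instance is just an example input):

import correctionlib.schemav2 as cs
from correctionlib.JSONEncoder import dumps

var = cs.Variable(name="pt", type="real", description="transverse momentum")
print(dumps(var))                                             # model -> model_dump(mode="json") -> JSONEncoder
print(dumps({"name": "pt", "type": "real"}, sort_keys=True))  # plain-dict path unchanged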
4 changes: 2 additions & 2 deletions src/correctionlib/cli.py
@@ -101,9 +101,9 @@ def merge(console: Console, args: argparse.Namespace) -> int:
cset.compound_corrections.append(corr2)
cset.description = "Merged from " + " ".join(args.files)
if args.format == "compact":
- sys.stdout.write(cset.json())
+ sys.stdout.write(cset.model_dump_json())
elif args.format == "indented":
- sys.stdout.write(cset.json(indent=4) + "\n")
+ sys.stdout.write(cset.model_dump_json(indent=4) + "\n")
elif args.format == "pretty":
from correctionlib.JSONEncoder import dumps

10 changes: 5 additions & 5 deletions src/correctionlib/convert.py
@@ -82,7 +82,7 @@ def read_axis(axis: "PlottableAxis", pos: int) -> Variable:
axname = getattr(
axis, "name", f"axis{pos}" if axis_names is None else axis_names[pos]
)
- return Variable.parse_obj(
+ return Variable.model_validate(
{
"type": axtype,
"name": axname,
@@ -120,7 +120,7 @@ def build_data(
) -> Content:
vartype = variables[0].type
if vartype in {"string", "int"}:
- return Category.parse_obj(
+ return Category.model_validate(
{
"nodetype": "category",
"input": variables[0].name,
@@ -142,7 +142,7 @@ def build_data(
break
i += 1
if i > 1:
- return MultiBinning.parse_obj(
+ return MultiBinning.model_validate(
{
"nodetype": "multibinning",
"edges": [edges(ax) for ax in axes[:i]],
@@ -156,7 +156,7 @@ def build_data(
"flow": flow,
}
)
- return Binning.parse_obj(
+ return Binning.model_validate(
{
"nodetype": "binning",
"input": variables[0].name,
@@ -171,7 +171,7 @@ def build_data(
}
)

- return Correction.parse_obj(
+ return Correction.model_validate(
{
"version": 0,
"name": getattr(hist, "name", "unknown"),
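For context, read_axis and build_data are the internals behind the histogram converter. A rough usage sketch of the conversion path (assuming the hist package is installed and using the module's from_histogram entry point, as in correctionlib's documentation):

import hist
import correctionlib.convert

h = (
    hist.Hist.new.Reg(10, 0.0, 100.0, name="pt")
    .StrCat(["b", "c", "light"], name="flavor")
    .Double()
)
h.fill(pt=[15.0, 45.0], flavor=["b", "light"])

corr = correctionlib.convert.from_histogram(h)  # a schemav2.Correction
# corr.data should be a Binning over "pt" whose bin contents are Category nodes over "flavor"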
6 changes: 3 additions & 3 deletions src/correctionlib/highlevel.py
@@ -40,11 +40,11 @@ def model_auto(data: str) -> Any:
if version == 1:
import correctionlib.schemav1

- return correctionlib.schemav1.CorrectionSet.parse_obj(data)
+ return correctionlib.schemav1.CorrectionSet.model_validate(data)
elif version == 2:
import correctionlib.schemav2

- return correctionlib.schemav2.CorrectionSet.parse_obj(data)
+ return correctionlib.schemav2.CorrectionSet.model_validate(data)
raise ValueError(f"Unknown CorrectionSet schema version ({version})")


@@ -292,7 +292,7 @@ def __init__(self, data: Any):
if isinstance(data, str):
self._data = data
else:
- self._data = data.json(exclude_unset=True)
+ self._data = data.model_dump_json(exclude_unset=True)
self._base = correctionlib._core.CorrectionSet.from_string(self._data)

@classmethod
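The high-level CorrectionSet wrapper still accepts either a JSON string or a schema model; with pydantic 2 the model branch is serialized via model_dump_json(exclude_unset=True) before being handed to the C++ core, and the schema models' .to_evaluator() shortcut (which this commit adds a targeted test for) goes through the same serialization. A minimal end-to-end sketch with a toy constant correction (illustrative names and values only):

import correctionlib
import correctionlib.schemav2 as cs

corr = cs.Correction(
    name="toy_sf",
    version=1,
    inputs=[cs.Variable(name="pt", type="real")],
    output=cs.Variable(name="weight", type="real"),
    data=1.05,  # constant scale factor node
)
cset = cs.CorrectionSet(schema_version=2, corrections=[corr])

ceval = correctionlib.CorrectionSet(cset)  # model is serialized with model_dump_json internally
print(ceval["toy_sf"].evaluate(25.0))      # 1.05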
9 changes: 4 additions & 5 deletions src/correctionlib/schemav1.py
@@ -1,7 +1,7 @@
import sys
from typing import List, Optional, Union

- from pydantic import BaseModel
+ from pydantic import BaseModel, ConfigDict

if sys.version_info >= (3, 8):
from typing import Literal
@@ -13,15 +13,14 @@


class Model(BaseModel):
- class Config:
- extra = "forbid"
+ model_config = ConfigDict(extra="forbid")


class Variable(Model):
name: str
type: Literal["string", "int", "real"]
"Implicitly 64 bit integer and double-precision floating point?"
- description: Optional[str]
+ description: Optional[str] = None
# TODO: clamping behavior for out of range?


@@ -72,7 +71,7 @@ class Category(Model):
class Correction(Model):
name: str
"A useful name"
- description: Optional[str]
+ description: Optional[str] = None
"Detailed description of the correction"
version: int
"Version"
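The remaining schema changes are the pydantic 2 configuration idioms: the inner class Config becomes model_config = ConfigDict(...), and Optional fields need an explicit default because v2 no longer treats Optional[...] as implicitly defaulted to None. A generic sketch of the two patterns (not the repository's schema classes):

from typing import Optional
from pydantic import BaseModel, ConfigDict

class Model(BaseModel):
    model_config = ConfigDict(extra="forbid")  # v1: class Config: extra = "forbid"

class Variable(Model):
    name: str
    description: Optional[str] = None          # v2 requires the explicit default

Variable(name="pt")               # ok, description defaults to None
Variable(name="pt", typo="oops")  # raises ValidationError because extra="forbid"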