Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make TimeVaryingArray and TimeVaryingScalar immutable. This is with a view to make the full model config immutable, which makes it possible to safely cache computations inside the model without worrying that a user has modified the field of eg. TimeVaryingArray, which is undetectable from a higher-level without major overhead. #730

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions torax/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class Grid1D:

nx: int
dx: float
face_centers: chex.Array
cell_centers: chex.Array
face_centers: np.ndarray
cell_centers: np.ndarray

def __post_init__(self):
jax_utils.assert_rank(self.nx, 0)
Expand Down
2 changes: 1 addition & 1 deletion torax/time_step_calculator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class TimeStepCalculatorType(enum.Enum):
FIXED = 'fixed'


class TimeStepCalculator(torax_pydantic.BaseModelMutable):
class TimeStepCalculator(torax_pydantic.BaseModelFrozen):
"""Config for a time step calculator."""

calculator_type: TimeStepCalculatorType = TimeStepCalculatorType.CHI
Expand Down
5 changes: 2 additions & 3 deletions torax/torax_pydantic/interpolated_param_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
class TimeVaryingScalar(interpolated_param_common.TimeVaryingBase):
"""Base class for time interpolated scalar types.

All fields are frozen after initialization.

The Pydantic `.model_validate` constructor can accept a variety of input types
defined by the `TimeInterpolatedInput` type. See
https://torax.readthedocs.io/en/latest/configuration.html#time-varying-scalars
Expand Down Expand Up @@ -54,9 +56,6 @@ def _conform_data(
) -> dict[str, Any]:

if isinstance(data, dict):
# A workaround for https://github.com/pydantic/pydantic/issues/10477.
data.pop('_get_cached_interpolated_param', None)

# This is the standard constructor input. No conforming required.
if set(data.keys()).issubset(cls.model_fields.keys()):
return data # pytype: disable=bad-return-type
Expand Down
48 changes: 44 additions & 4 deletions torax/torax_pydantic/interpolated_param_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
import pydantic
from torax import interpolated_param
from torax.geometry import geometry
from torax.torax_pydantic import interpolated_param_common
from torax.torax_pydantic import model_base
import xarray as xr
Expand All @@ -28,6 +29,9 @@
class TimeVaryingArray(interpolated_param_common.TimeVaryingBase):
"""Base class for time interpolated array types.

All fields are frozen after initialization, but `rho_norm_grid` can be set
after initialization via the `set_rho_norm_grid` method.

The Pydantic `.model_validate` constructor can accept a variety of input types
defined by the `TimeRhoInterpolatedInput` type. See
https://torax.readthedocs.io/en/latest/configuration.html#time-varying-arrays
Expand All @@ -38,7 +42,9 @@ class TimeVaryingArray(interpolated_param_common.TimeVaryingBase):
`rho_norm` and `values` are 1D NumPy arrays of equal length.
rho_interpolation_mode: The interpolation mode to use for the rho axis.
time_interpolation_mode: The interpolation mode to use for the time axis.
rho_norm_grid: The rho norm grid to use for the interpolation.
rho_norm_grid: The rho norm grid to use for the interpolation. This is
generally not known at initialization time, so it is set separately via
the `set_rho_norm_grid` method.
"""

value: Mapping[float, tuple[model_base.NumpyArray1D, model_base.NumpyArray1D]]
Expand All @@ -50,16 +56,22 @@ class TimeVaryingArray(interpolated_param_common.TimeVaryingBase):
)
rho_norm_grid: model_base.NumpyArray | None = None

@functools.cached_property
def rhonorm1_defined_in_timerhoinput(self) -> bool:
"""Checks if the boundary condition at rho=1.0 is always defined."""

for _, (rho_norm, _) in self.value.items():
if 1.0 not in rho_norm:
return False
return True

@pydantic.model_validator(mode='before')
@classmethod
def _conform_data(
cls, data: interpolated_param.TimeRhoInterpolatedInput | dict[str, Any]
) -> dict[str, Any]:

if isinstance(data, dict):
# A workaround for https://github.com/pydantic/pydantic/issues/10477.
data.pop('_get_cached_interpolated_param', None)

# This is the standard constructor input. No conforming required.
if set(data.keys()).issubset(cls.model_fields.keys()):
return data
Expand Down Expand Up @@ -115,6 +127,34 @@ def _get_cached_interpolated_param(
rho_interpolation_mode=self.rho_interpolation_mode,
)

def set_rho_norm_grid(self, grid: model_base.NumpyArray | geometry.Grid1D):
"""Sets the rho_norm_grid field.

This function can only be called if the rho_norm_grid field is None.

Args:
grid: The grid to use for interpolation, either as a NumPy array or a
geometry.Grid1D object.

Raises:
RuntimeError: If the rho_norm_grid field is not None.
Returns:
No return value.
"""

if self.rho_norm_grid is not None:
raise RuntimeError(
'set_rho_norm_grid can only be called when the rho_norm_grid field is'
' None.'
)

if isinstance(grid, geometry.Grid1D):
grid = grid.cell_centers

# Bypass the pydantic validator that enforces field immutability by
# directly modifying the underlying model dictionary.
self.__dict__['rho_norm_grid'] = grid


def _load_from_primitives(
primitive_values: (
Expand Down
12 changes: 1 addition & 11 deletions torax/torax_pydantic/interpolated_param_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
import abc
import functools
import chex
import pydantic
from torax import interpolated_param
from torax.torax_pydantic import model_base
from typing_extensions import Self


class TimeVaryingBase(model_base.BaseModelMutable):
class TimeVaryingBase(model_base.BaseModelFrozen):
"""Base class for time varying interpolated parameters."""

def get_value(self, x: chex.Numeric) -> chex.Array:
Expand Down Expand Up @@ -58,11 +56,3 @@ def _get_cached_interpolated_param(
):
"""Returns the value of this parameter interpolated at x=time."""
...

@pydantic.model_validator(mode='after')
def clear_cached_property(self) -> Self:
try:
del self._get_cached_interpolated_param
except AttributeError:
pass
return self
33 changes: 22 additions & 11 deletions torax/torax_pydantic/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jax
import numpy as np
import pydantic
from torax.geometry import geometry
from typing_extensions import Self

DataTypes: TypeAlias = float | int | bool
Expand Down Expand Up @@ -70,21 +71,21 @@ def _numpy_array_is_rank_1(x: np.ndarray) -> np.ndarray:
]


class BaseModelMutable(pydantic.BaseModel):
class BaseModelFrozen(pydantic.BaseModel):
"""Base config class. Any custom config classes should inherit from this.

No model fields are allowed to be assigned to after construction.

See https://docs.pydantic.dev/latest/ for documentation on pydantic.

This class is compatible with JAX, so can be used as an argument to a JITted
function.
"""

model_config = pydantic.ConfigDict(
frozen=False,
frozen=True,
# Do not allow attributes not defined in pydantic model.
extra='forbid',
# Re-run validation if the model is updated.
validate_assignment=True,
arbitrary_types_allowed=True,
)

Expand Down Expand Up @@ -121,14 +122,24 @@ def from_dict(cls: type[Self], cfg: Mapping[str, Any]) -> Self:
def to_dict(self) -> dict[str, Any]:
return self.model_dump()

def set_rho_norm_grid(self, grid: NumpyArray | geometry.Grid1D):
"""Sets the rho_norm_grid field in all TimeVaryingArray fields.

class BaseModelFrozen(BaseModelMutable, frozen=True):
"""Base config with frozen fields.
This will set the grid to all sub-models as well.

See https://docs.pydantic.dev/latest/ for documentation on pydantic.
This function can only be called if the rho_norm_grid field is None.

This class is compatible with JAX, so can be used as an argument to a JITted
function.
"""
Args:
grid: The grid to use for interpolation, either as a NumPy array or a
geometry.Grid1D object.

Raises:
RuntimeError: If the rho_norm_grid field is not None.
Returns:
No return value.
"""

...
for name in self.model_fields.keys():
attr = getattr(self, name)
if hasattr(attr, 'set_rho_norm_grid'):
attr.set_rho_norm_grid(grid)
2 changes: 1 addition & 1 deletion torax/torax_pydantic/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torax.torax_pydantic import model_base


class ToraxConfig(model_base.BaseModelMutable):
class ToraxConfig(model_base.BaseModelFrozen):
"""Base config class for Torax.

Attributes:
Expand Down
2 changes: 1 addition & 1 deletion torax/torax_pydantic/tests/interpolated_param_1d_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_time_varying_model_basic(

default_value = 1.53

class TestModel(torax_pydantic.BaseModelMutable):
class TestModel(torax_pydantic.BaseModelFrozen):
a: interpolated_param_1d.TimeVaryingScalar
b: interpolated_param_1d.TimeVaryingScalar = pydantic.Field(
default_factory=lambda: default_value, validate_default=True
Expand Down
30 changes: 20 additions & 10 deletions torax/torax_pydantic/tests/interpolated_param_2d_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from torax.geometry import circular_geometry
from torax.torax_pydantic import interpolated_param_2d
import xarray as xr

Expand Down Expand Up @@ -214,7 +215,7 @@ def test_interpolated_var_time_rho_parses_inputs_correctly(
interpolated = interpolated_param_2d.TimeVaryingArray.model_validate(
time_rho_interpolated_input
)
interpolated.rho_norm_grid = rho_norm
interpolated.set_rho_norm_grid(rho_norm)

np.testing.assert_allclose(
interpolated.get_value(x=time),
Expand All @@ -225,17 +226,26 @@ def test_interpolated_var_time_rho_parses_inputs_correctly(

def test_mutation_behavior(self):
v1 = 1.0
v2 = 2.0
interpolated = interpolated_param_2d.TimeVaryingArray.model_validate(v1)
interpolated.rho_norm_grid = np.array([0.0, 0.5, 1.0])
out1 = interpolated.get_value(x=0.0)
self.assertEqual(out1.tolist(), [v1, v1, v1])

# Modifying the value should change the output of get_value. This tests
# that caching is working correctly.
interpolated.value = {0.0: (np.array([0.0]), np.array([v2]))}
out2 = interpolated.get_value(x=0.0)
self.assertEqual(out2.tolist(), [v2, v2, v2])
# Directly setting the grid is banned due to immutability.
with self.assertRaises(ValueError):
interpolated.rho_norm_grid = np.array([0.0, 0.5, 1.0])

# The grid is not set, so we should raise an error as there is not enough
# information to interpolate.
with self.assertRaises(ValueError):
interpolated.get_value(x=0.0)

geo = circular_geometry.build_circular_geometry()
interpolated.set_rho_norm_grid(geo.torax_mesh)

# Setting the grid twice should raise an error.
with self.assertRaises(RuntimeError):
interpolated.set_rho_norm_grid(geo.torax_mesh)

out1 = interpolated.get_value(x=0.0)
self.assertEqual(out1.tolist(), [v1] * len(interpolated.rho_norm_grid))


if __name__ == '__main__':
Expand Down
73 changes: 35 additions & 38 deletions torax/torax_pydantic/tests/model_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

"""Unit tests for the `torax.torax_pydantic.model_base` module."""

import functools
from absl.testing import absltest
from absl.testing import parameterized
import jax
import numpy as np
import pydantic
from torax.torax_pydantic import interpolated_param_2d
from torax.torax_pydantic import model_base


Expand Down Expand Up @@ -82,44 +82,11 @@ class TestModel(model_base.BaseModelFrozen):
with self.assertRaises(ValueError):
m.x = 2.0

def test_model_base(self):
def test_model_base_map_pytree(self):

class Test(model_base.BaseModelMutable, validate_assignment=True):
name: str

@functools.cached_property
def computed(self):
return self.name + '_test' # pytype: disable=attribute-error

@pydantic.model_validator(mode='after')
def validate(self):
if hasattr(self, 'computed'):
del self.computed
return self

m = Test(name='test_string')
self.assertEqual(m.computed, 'test_string_test')

with self.subTest('field_is_mutable'):
m.name = 'new_test_string'

with self.subTest('after_model_validator_is_called_on_update'):
self.assertEqual(m.computed, 'new_test_string_test')

@parameterized.parameters(True, False)
def test_model_base_map_pytree(self, frozen: bool):

if frozen:

class TestModel(model_base.BaseModelFrozen):
x: float
y: float

else:

class TestModel(model_base.BaseModelMutable):
x: float
y: float
class TestModel(model_base.BaseModelFrozen):
x: float
y: float

m = TestModel(x=2.0, y=4.0)
m2 = jax.tree_util.tree_map(lambda x: x**2, m)
Expand All @@ -134,6 +101,36 @@ def f(data):
with self.subTest('jit_works'):
self.assertEqual(f(m), m.x * m.y)

def test_model_set_grid(self):

class LowerModel(model_base.BaseModelFrozen):
x: float
y: interpolated_param_2d.TimeVaryingArray

class TestModel(model_base.BaseModelFrozen):
x: int
y: interpolated_param_2d.TimeVaryingArray
z: LowerModel # pytype: disable=invalid-annotation

m = TestModel(
x=1,
y=interpolated_param_2d.TimeVaryingArray.model_validate(1.0),
z=LowerModel(
x=1.0, y=interpolated_param_2d.TimeVaryingArray.model_validate(2.0)
),
)

grid = np.array([1.0, 2.0, 3.0])
m.set_rho_norm_grid(grid)
# This test ensures that the grid is correctly set, and that no copies of
# the grid are made.
self.assertIs(m.y.rho_norm_grid, grid)
self.assertIs(m.z.y.rho_norm_grid, grid)

with self.subTest('cannot_set_grid_twice'):
with self.assertRaises(RuntimeError):
m.y.set_rho_norm_grid(grid)


if __name__ == '__main__':
absltest.main()
Loading