From 1ba4969306a3c3880b3bd41ef1ba5dbf4b5a796e Mon Sep 17 00:00:00 2001 From: Sebastian Bodenstein Date: Wed, 12 Feb 2025 08:45:19 -0800 Subject: [PATCH] Switch Grid1D to pydantic model. PiperOrigin-RevId: 726070966 --- torax/geometry/geometry.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/torax/geometry/geometry.py b/torax/geometry/geometry.py index 41d19019..23a342f2 100644 --- a/torax/geometry/geometry.py +++ b/torax/geometry/geometry.py @@ -25,11 +25,11 @@ import jax import jax.numpy as jnp import numpy as np -from torax import jax_utils +import pydantic +from torax.torax_pydantic import torax_pydantic -@chex.dataclass(frozen=True) -class Grid1D: +class Grid1D(torax_pydantic.BaseModelFrozen): """Data structure defining a 1-D grid of cells with faces. Construct via `construct` classmethod. @@ -41,16 +41,10 @@ class Grid1D: cell_centers: Coordinates of cell centers. """ - nx: int - dx: float - face_centers: chex.Array - cell_centers: chex.Array - - def __post_init__(self): - jax_utils.assert_rank(self.nx, 0) - jax_utils.assert_rank(self.dx, 0) - jax_utils.assert_rank(self.face_centers, 1) - jax_utils.assert_rank(self.cell_centers, 1) + nx: pydantic.PositiveInt + dx: pydantic.PositiveFloat + face_centers: torax_pydantic.NumpyArray1D + cell_centers: torax_pydantic.NumpyArray1D def __eq__(self, other: Grid1D) -> bool: return ( @@ -248,7 +242,7 @@ def g1_over_vpr_face(self) -> jax.Array: @property def g1_over_vpr2_face(self) -> jax.Array: - bulk = self.g1_face[..., 1:] / self.vpr_face[..., 1:]**2 + bulk = self.g1_face[..., 1:] / self.vpr_face[..., 1:] ** 2 # Correct value on-axis is 1/rho_b**2 first_element = jnp.ones_like(self.rho_b) / self.rho_b**2 return jnp.concatenate( @@ -260,17 +254,15 @@ def z_magnetic_axis(self) -> chex.Numeric: if z_magnetic_axis is not None: return z_magnetic_axis else: - raise ValueError( - 'Geometry does not have a z magnetic axis.' - ) + raise ValueError('Geometry does not have a z magnetic axis.') def stack_geometries(geometries: Sequence[Geometry]) -> Geometry: """Batch together a sequence of geometries. Args: - geometries: A sequence of geometries to stack. The geometries must have - the same mesh, geometry type. + geometries: A sequence of geometries to stack. The geometries must have the + same mesh, geometry type. Returns: A Geometry object, where each array attribute has an additional