Skip to content

Commit

Permalink
Switch Grid1D to pydantic model.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726028947
  • Loading branch information
sbodenstein authored and Torax team committed Feb 12, 2025
1 parent 7904949 commit 8ae13a0
Showing 1 changed file with 11 additions and 19 deletions.
30 changes: 11 additions & 19 deletions torax/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import jax
import jax.numpy as jnp
import numpy as np
from torax import jax_utils
import pydantic
from torax.torax_pydantic import torax_pydantic


@chex.dataclass(frozen=True)
class Grid1D:
class Grid1D(torax_pydantic.BaseModelFrozen):
"""Data structure defining a 1-D grid of cells with faces.
Construct via `construct` classmethod.
Expand All @@ -41,16 +41,10 @@ class Grid1D:
cell_centers: Coordinates of cell centers.
"""

nx: int
dx: float
face_centers: chex.Array
cell_centers: chex.Array

def __post_init__(self):
jax_utils.assert_rank(self.nx, 0)
jax_utils.assert_rank(self.dx, 0)
jax_utils.assert_rank(self.face_centers, 1)
jax_utils.assert_rank(self.cell_centers, 1)
nx: pydantic.PositiveInt
dx: pydantic.PositiveFloat
face_centers: torax_pydantic.NumpyArray1D
cell_centers: torax_pydantic.NumpyArray1D

def __eq__(self, other: Grid1D) -> bool:
return (
Expand Down Expand Up @@ -248,7 +242,7 @@ def g1_over_vpr_face(self) -> jax.Array:

@property
def g1_over_vpr2_face(self) -> jax.Array:
bulk = self.g1_face[..., 1:] / self.vpr_face[..., 1:]**2
bulk = self.g1_face[..., 1:] / self.vpr_face[..., 1:] ** 2
# Correct value on-axis is 1/rho_b**2
first_element = jnp.ones_like(self.rho_b) / self.rho_b**2
return jnp.concatenate(
Expand All @@ -260,17 +254,15 @@ def z_magnetic_axis(self) -> chex.Numeric:
if z_magnetic_axis is not None:
return z_magnetic_axis
else:
raise ValueError(
'Geometry does not have a z magnetic axis.'
)
raise ValueError('Geometry does not have a z magnetic axis.')


def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
"""Batch together a sequence of geometries.
Args:
geometries: A sequence of geometries to stack. The geometries must have
the same mesh, geometry type.
geometries: A sequence of geometries to stack. The geometries must have the
same mesh, geometry type.
Returns:
A Geometry object, where each array attribute has an additional
Expand Down

0 comments on commit 8ae13a0

Please sign in to comment.