Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch Grid1D to pydantic model. #731

Merged
merged 1 commit into from
Feb 12, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 11 additions & 19 deletions torax/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import jax
import jax.numpy as jnp
import numpy as np
from torax import jax_utils
import pydantic
from torax.torax_pydantic import torax_pydantic


@chex.dataclass(frozen=True)
class Grid1D:
class Grid1D(torax_pydantic.BaseModelFrozen):
"""Data structure defining a 1-D grid of cells with faces.

Construct via `construct` classmethod.
Expand All @@ -41,16 +41,10 @@ class Grid1D:
cell_centers: Coordinates of cell centers.
"""

nx: int
dx: float
face_centers: chex.Array
cell_centers: chex.Array

def __post_init__(self):
jax_utils.assert_rank(self.nx, 0)
jax_utils.assert_rank(self.dx, 0)
jax_utils.assert_rank(self.face_centers, 1)
jax_utils.assert_rank(self.cell_centers, 1)
nx: pydantic.PositiveInt
dx: pydantic.PositiveFloat
face_centers: torax_pydantic.NumpyArray1D
cell_centers: torax_pydantic.NumpyArray1D

def __eq__(self, other: Grid1D) -> bool:
return (
Expand Down Expand Up @@ -248,7 +242,7 @@ def g1_over_vpr_face(self) -> jax.Array:

@property
def g1_over_vpr2_face(self) -> jax.Array:
bulk = self.g1_face[..., 1:] / self.vpr_face[..., 1:]**2
bulk = self.g1_face[..., 1:] / self.vpr_face[..., 1:] ** 2
# Correct value on-axis is 1/rho_b**2
first_element = jnp.ones_like(self.rho_b) / self.rho_b**2
return jnp.concatenate(
Expand All @@ -260,17 +254,15 @@ def z_magnetic_axis(self) -> chex.Numeric:
if z_magnetic_axis is not None:
return z_magnetic_axis
else:
raise ValueError(
'Geometry does not have a z magnetic axis.'
)
raise ValueError('Geometry does not have a z magnetic axis.')


def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
"""Batch together a sequence of geometries.

Args:
geometries: A sequence of geometries to stack. The geometries must have
the same mesh, geometry type.
geometries: A sequence of geometries to stack. The geometries must have the
same mesh, geometry type.

Returns:
A Geometry object, where each array attribute has an additional
Expand Down