Skip to content

Commit

Permalink
Extend geometry outputs to include variables defined as Geometry prop…
Browse files Browse the repository at this point in the history
…erties.

Drive-bys:
1. geometry attribute in output.StateHistory is now a Geometry class or child class, with stacked arrays, similar to core_profiles, core_sources, and core_transport.
2. drho_norm is now a property (just depends on the static mesh)
3. Extended the tests

All benchmark cases regenerated with the new outputs. Tests passed before the regeneration.

PiperOrigin-RevId: 724269913
  • Loading branch information
jcitrin authored and Torax team committed Feb 7, 2025
1 parent 881519b commit 1264dd1
Show file tree
Hide file tree
Showing 60 changed files with 165 additions and 88 deletions.
1 change: 0 additions & 1 deletion torax/geometry/circular_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ def build_circular_geometry(
return CircularAnalyticalGeometry(
# Set the standard geometry params.
geometry_type=geometry.GeometryType.CIRCULAR.value,
drho_norm=np.asarray(drho_norm),
torax_mesh=mesh,
Phi=Phi,
Phi_face=Phi_face,
Expand Down
101 changes: 56 additions & 45 deletions torax/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@

from __future__ import annotations

from collections.abc import Mapping, Sequence
from collections.abc import Sequence
import dataclasses
import enum

import chex
import jax
import jax.numpy as jnp
import numpy as np
from torax import array_typing
from torax import jax_utils


Expand Down Expand Up @@ -125,12 +124,14 @@ class Geometry:
Most users should default to using the StandardGeometry class, whether the
source of their geometry comes from CHEASE, MEQ, etc.
Properties work for both 1D radial arrays and 2D stacked arrays where the
leading dimension is time.
"""

# TODO(b/356356966): extend documentation to define what each attribute is.
geometry_type: int
torax_mesh: Grid1D
drho_norm: array_typing.ArrayFloat
Phi: chex.Array
Phi_face: chex.Array
Rmaj: chex.Array
Expand Down Expand Up @@ -182,13 +183,17 @@ def rho_norm(self) -> chex.Array:
def rho_face_norm(self) -> chex.Array:
return self.torax_mesh.face_centers

@property
def drho_norm(self) -> chex.Array:
return jnp.array(self.torax_mesh.dx)

@property
def rho_face(self) -> chex.Array:
return self.rho_face_norm * self.rho_b
return self.rho_face_norm * jnp.expand_dims(self.rho_b, axis=-1)

@property
def rho(self) -> chex.Array:
return self.rho_norm * self.rho_b
return self.rho_norm * jnp.expand_dims(self.rho_b, axis=-1)

@property
def rmid(self) -> chex.Array:
Expand All @@ -210,7 +215,7 @@ def rho_b(self) -> chex.Array:
@property
def Phib(self) -> chex.Array:
"""Toroidal flux at boundary (LCFS)."""
return self.Phi_face[-1]
return self.Phi_face[..., -1]

@property
def g1_over_vpr(self) -> chex.Array:
Expand All @@ -222,24 +227,33 @@ def g1_over_vpr2(self) -> chex.Array:

@property
def g0_over_vpr_face(self) -> jax.Array:
return jnp.concatenate((
jnp.ones(1) / self.rho_b, # correct value is 1/rho_b on-axis
self.g0_face[1:] / self.vpr_face[1:], # avoid div by zero on-axis
))
# Calculate the bulk of the array (excluding the first element)
# to avoid division by zero.
bulk = self.g0_face[..., 1:] / self.vpr_face[..., 1:]
# Correct value on-axis is 1/rho_b
first_element = jnp.ones_like(self.rho_b) / self.rho_b
# Concatenate to handle both 1D (no leading dim) and 2D cases
return jnp.concatenate(
[jnp.expand_dims(first_element, axis=-1), bulk], axis=-1
)

@property
def g1_over_vpr_face(self) -> jax.Array:
return jnp.concatenate((
jnp.zeros(1), # correct value is zero on-axis
self.g1_face[1:] / self.vpr_face[1:], # avoid div by zero on-axis
))
bulk = self.g1_face[..., 1:] / self.vpr_face[..., 1:]
# Correct value on-axis is 0
first_element = jnp.zeros_like(self.rho_b)
return jnp.concatenate(
[jnp.expand_dims(first_element, axis=-1), bulk], axis=-1
)

@property
def g1_over_vpr2_face(self) -> jax.Array:
return jnp.concatenate((
jnp.ones(1) / self.rho_b**2, # correct value is 1/rho_b**2 on-axis
self.g1_face[1:] / self.vpr_face[1:] ** 2, # avoid div by zero on-axis
))
bulk = self.g1_face[..., 1:] / self.vpr_face[..., 1:]**2
# Correct value on-axis is 1/rho_b**2
first_element = jnp.ones_like(self.rho_b) / self.rho_b**2
return jnp.concatenate(
[jnp.expand_dims(first_element, axis=-1), bulk], axis=-1
)

def z_magnetic_axis(self) -> chex.Numeric:
z_magnetic_axis = self._z_magnetic_axis
Expand All @@ -251,43 +265,40 @@ def z_magnetic_axis(self) -> chex.Numeric:
)


def stack_geometries(
geometries: Sequence[Geometry],
) -> Mapping[str, chex.Array]:
def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
"""Batch together a sequence of geometries.
The batched geometries are returned as a dictionary of arrays. Any fields
that are not arrays are not included in the dictionary and all properties
are excluded as well.
Args:
geometries: A sequence of geometries to stack. The geometries must have the
same mesh, geometry type, and drho_norm.
geometries: A sequence of geometries to stack. The geometries must have
the same mesh, geometry type.
Returns:
A dictionary of arrays, where each array has the same shape as the
corresponding attribute in the input geometries, but with an additional
leading axis (e.g. for the time dimension).
A Geometry object, where each array attribute has an additional
leading axis (e.g. for the time dimension) compared to each Geometry in
the input sequence.
"""
if not geometries:
raise ValueError('No geometries provided.')
# Stack the geometries.
torax_mesh = geometries[0].torax_mesh
geometry_type = geometries[0].geometry_type
drho_norm = geometries[0].drho_norm
# Ensure that all geometries have same mesh and are of same type.
first_geo = geometries[0]
torax_mesh = first_geo.torax_mesh
geometry_type = first_geo.geometry_type
for geometry in geometries[1:]:
if geometry.torax_mesh != torax_mesh:
raise ValueError('All geometries must have the same mesh.')
if geometry.geometry_type != geometry_type:
raise ValueError('All geometries must have the same geometry type.')
if geometry.drho_norm != drho_norm:
raise ValueError('All geometries must have the same drho_norm.')
array_fields = [
attr
for attr, val in dataclasses.asdict(geometries[0]).items()
if isinstance(val, chex.Array)
]
return {
attr: jnp.stack([getattr(geo, attr) for geo in geometries])
for attr in array_fields
}

stacked_data = {}
for field in dataclasses.fields(first_geo):
field_name = field.name
field_value = getattr(first_geo, field_name)
# Stack stackable fields. Save first geo's value for non-stackable fields.
if isinstance(field_value, chex.Array):
field_values = [getattr(geo, field_name) for geo in geometries]
stacked_data[field_name] = jnp.stack(field_values)
else:
stacked_data[field_name] = field_value
# Create a new object with the stacked data with the same class (i.e.
# could be child classes of Geometry)
return first_geo.__class__(**stacked_data)
7 changes: 3 additions & 4 deletions torax/geometry/standard_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,9 @@ def from_chease(
return cls(
geometry_type=geometry.GeometryType.CHEASE,
Ip_from_parameters=Ip_from_parameters,
Rmaj=Rmaj,
Rmin=Rmin,
B=B0,
Rmaj=np.array(Rmaj),
Rmin=np.array(Rmin),
B=np.array(B0),
psi=psi,
Ip_profile=Ip_chease,
Phi=Phi,
Expand Down Expand Up @@ -1008,7 +1008,6 @@ def build_standard_geometry(

return StandardGeometry(
geometry_type=intermediate.geometry_type.value,
drho_norm=np.asarray(drho_norm),
torax_mesh=mesh,
Phi=Phi,
Phi_face=Phi_face,
Expand Down
76 changes: 64 additions & 12 deletions torax/geometry/tests/geometry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,70 @@ def test_none_z_magnetic_axis_raises_an_error(self):
_ = foo()

def test_stack_geometries(self):
"""Test that geometries can be stacked."""
geo_0 = circular_geometry.build_circular_geometry(Rmaj=1, B0=5.3)
geo_1 = circular_geometry.build_circular_geometry(Rmaj=2, B0=3.7)
stacked_geometries = geometry.stack_geometries([geo_0, geo_1])
self.assertNotIn('geometry_type', stacked_geometries.keys())
self.assertNotIn('torax_mesh', stacked_geometries.keys())

for key, value in stacked_geometries.items():
self.assertEqual(value.shape, (2,) + getattr(geo_0, key).shape)

np.testing.assert_allclose(stacked_geometries['Rmaj'], np.array([1, 2]))
np.testing.assert_allclose(stacked_geometries['B0'], np.array([5.3, 3.7]))
"""Test for stack_geometries."""
# Create a few different geometries
geo0 = circular_geometry.build_circular_geometry(Rmaj=1.0, B0=2.0, n_rho=10)
geo1 = circular_geometry.build_circular_geometry(Rmaj=1.5, B0=2.5, n_rho=10)
geo2 = circular_geometry.build_circular_geometry(Rmaj=2.0, B0=3.0, n_rho=10)

# Stack them
stacked_geo = geometry.stack_geometries([geo0, geo1, geo2])

# Check that the stacked geometry has the correct type and mesh
self.assertEqual(stacked_geo.geometry_type, geo0.geometry_type)
self.assertEqual(stacked_geo.torax_mesh, geo0.torax_mesh)

# Check some specific stacked values
np.testing.assert_allclose(stacked_geo.Rmaj, np.array([1.0, 1.5, 2.0]))
np.testing.assert_allclose(stacked_geo.B0, np.array([2.0, 2.5, 3.0]))
np.testing.assert_allclose(
stacked_geo.Phi_face[:, -1],
np.array([geo0.Phi_face[-1], geo1.Phi_face[-1], geo2.Phi_face[-1]]),
)

# Check stacking of derived properties
np.testing.assert_allclose(
stacked_geo.rho_b, np.array([geo0.rho_b, geo1.rho_b, geo2.rho_b])
)

# Check a property that depends on a stacked property (rho depends on rho_b)
np.testing.assert_allclose(
stacked_geo.rho,
np.array([
stacked_geo.rho_norm * geo0.rho_b,
stacked_geo.rho_norm * geo1.rho_b,
stacked_geo.rho_norm * geo2.rho_b,
]),
)

# Check properties with special handling for on-axis values.
np.testing.assert_allclose(
stacked_geo.g0_over_vpr_face[:, 0], 1 / stacked_geo.rho_b
)
np.testing.assert_allclose(
stacked_geo.g1_over_vpr2_face[:, 0], 1 / stacked_geo.rho_b**2
)

# Test stacking with an empty list (should raise ValueError)
with self.assertRaisesRegex(ValueError, 'No geometries provided.'):
geometry.stack_geometries([])

# Test stacking with geometries of different mesh sizes
# (should raise ValueError)
geo_diff_mesh = circular_geometry.build_circular_geometry(
Rmaj=1.0, B0=2.0, n_rho=20
) # Different n_rho
with self.assertRaisesRegex(
ValueError, 'All geometries must have the same mesh.'
):
geometry.stack_geometries([geo0, geo_diff_mesh])

# Test with geometries that has a different geometry type
geo_diff_mesh = dataclasses.replace(geo1, geometry_type=3)
with self.assertRaisesRegex(
ValueError, 'All geometries must have the same geometry type'
):
geometry.stack_geometries([geo0, geo_diff_mesh])


def _pint_face_to_cell(n_rho, face):
Expand Down
15 changes: 0 additions & 15 deletions torax/geometry/tests/standard_geometry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import jax
import numpy as np
from torax.config import build_sim
from torax.geometry import circular_geometry
from torax.geometry import geometry
from torax.geometry import geometry_loader
from torax.geometry import standard_geometry
Expand Down Expand Up @@ -151,20 +150,6 @@ def test_validate_fbt_data_incorrect_L_pQ_shape(self):
with self.assertRaisesRegex(ValueError, 'Incorrect shape'):
standard_geometry._validate_fbt_data(LY, L)

def test_stack_geometries(self):
"""Test that geometries can be stacked."""
geo_0 = circular_geometry.build_circular_geometry(Rmaj=1, B0=5.3)
geo_1 = circular_geometry.build_circular_geometry(Rmaj=2, B0=3.7)
stacked_geometries = geometry.stack_geometries([geo_0, geo_1])
self.assertNotIn('geometry_type', stacked_geometries.keys())
self.assertNotIn('torax_mesh', stacked_geometries.keys())

for key, value in stacked_geometries.items():
self.assertEqual(value.shape, (2,) + getattr(geo_0, key).shape)

np.testing.assert_allclose(stacked_geometries['Rmaj'], np.array([1, 2]))
np.testing.assert_allclose(stacked_geometries['B0'], np.array([5.3, 3.7]))


def _get_example_L_LY_data(len_psinorm: int, len_times: int):
LY = {
Expand Down
Loading

0 comments on commit 1264dd1

Please sign in to comment.