Skip to content

Commit

Permalink
Change get_value to return a tuple of values, reducing stacking and u…
Browse files Browse the repository at this point in the history
…nstacking and simplifying code

PiperOrigin-RevId: 724276085
  • Loading branch information
tamaranorman authored and Torax team committed Feb 12, 2025
1 parent ec8513d commit 09933bb
Show file tree
Hide file tree
Showing 23 changed files with 114 additions and 232 deletions.
2 changes: 1 addition & 1 deletion torax/sources/bootstrap_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def get_value(
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
calculated_source_profiles: source_profiles.SourceProfiles | None,
) -> chex.Array:
) -> tuple[chex.Array, ...]:
raise NotImplementedError('Call `get_bootstrap` instead.')

def get_source_profile_for_affected_core_profile(
Expand Down
4 changes: 2 additions & 2 deletions torax/sources/bremsstrahlung_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def bremsstrahlung_model_func(
core_profiles: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_model_func: source_models.SourceModels | None,
) -> jax.Array:
) -> tuple[chex.Array, ...]:
"""Model function for the Bremsstrahlung heat sink."""
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
source_name
Expand All @@ -143,7 +143,7 @@ def bremsstrahlung_model_func(
use_relativistic_correction=dynamic_source_runtime_params.use_relativistic_correction,
)
# As a sink, the power is negative.
return -1.0 * P_brem_profile
return (-1.0 * P_brem_profile,)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
Expand Down
4 changes: 2 additions & 2 deletions torax/sources/cyclotron_radiation_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def cyclotron_radiation_albajar(
core_profiles: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels,
) -> array_typing.ArrayFloat:
) -> tuple[array_typing.ArrayFloat, ...]:
"""Calculates the cyclotron radiation heat sink contribution to the electron heat equation.
Total cyclotron radiation is from:
Expand Down Expand Up @@ -402,7 +402,7 @@ def cyclotron_radiation_albajar(
rescaling_factor = P_cycl_total / denom
Q_cycl = Q_cycl_shape * rescaling_factor

return -Q_cycl
return (-Q_cycl,)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
Expand Down
5 changes: 2 additions & 3 deletions torax/sources/electron_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import ClassVar

import chex
import jax
import jax.numpy as jnp
from torax import array_typing
from torax import constants
Expand Down Expand Up @@ -112,7 +111,7 @@ def calc_heating_and_current(
core_profiles: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels | None = None,
) -> jax.Array:
) -> tuple[chex.Array, ...]:
"""Model function for the electron-cyclotron source.
Based on Lin-Liu, Y. R., Chan, V. S., & Prater, R. (2003).
Expand Down Expand Up @@ -178,7 +177,7 @@ def calc_heating_and_current(
j_ec_dot_B = jnp.exp(log_j_ec_dot_B)
# pylint: enable=invalid-name

return jnp.stack([ec_power_density, j_ec_dot_B])
return ec_power_density, j_ec_dot_B


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
Expand Down
37 changes: 12 additions & 25 deletions torax/sources/electron_density_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import ClassVar

import chex
import jax
from torax import array_typing
from torax import interpolated_param
from torax import state
Expand Down Expand Up @@ -78,32 +77,28 @@ class DynamicGasPuffRuntimeParams(runtime_params_lib.DynamicRuntimeParams):

# Default formula: exponential with nref normalization.
def calc_puff_source(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
unused_state: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels | None = None,
) -> jax.Array:
) -> tuple[chex.Array, ...]:
"""Calculates external source term for n from puffs."""
del (
unused_source_models,
static_runtime_params_slice,
) # Unused.
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
source_name
]
assert isinstance(dynamic_source_runtime_params, DynamicGasPuffRuntimeParams)
return formulas.exponential_profile(
return (formulas.exponential_profile(
c1=1.0,
c2=dynamic_source_runtime_params.puff_decay_length,
total=(
dynamic_source_runtime_params.S_puff_tot
/ dynamic_runtime_params_slice.numerics.nref
),
geo=geo,
)
),)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
Expand Down Expand Up @@ -175,32 +170,28 @@ class DynamicParticleRuntimeParams(runtime_params_lib.DynamicRuntimeParams):


def calc_generic_particle_source(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
unused_state: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels | None = None,
) -> jax.Array:
) -> tuple[chex.Array, ...]:
"""Calculates external source term for n from SBI."""
del (
unused_source_models,
static_runtime_params_slice,
) # Unused.
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
source_name
]
assert isinstance(dynamic_source_runtime_params, DynamicParticleRuntimeParams)
return formulas.gaussian_profile(
return (formulas.gaussian_profile(
c1=dynamic_source_runtime_params.deposition_location,
c2=dynamic_source_runtime_params.particle_width,
total=(
dynamic_source_runtime_params.S_tot
/ dynamic_runtime_params_slice.numerics.nref
),
geo=geo,
)
),)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
Expand Down Expand Up @@ -265,32 +256,28 @@ class DynamicPelletRuntimeParams(runtime_params_lib.DynamicRuntimeParams):


def calc_pellet_source(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
unused_static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
source_name: str,
unused_state: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels | None = None,
) -> jax.Array:
) -> tuple[chex.Array, ...]:
"""Calculates external source term for n from pellets."""
del (
unused_source_models,
static_runtime_params_slice,
) # Unused.
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
source_name
]
assert isinstance(dynamic_source_runtime_params, DynamicPelletRuntimeParams)
return formulas.gaussian_profile(
return (formulas.gaussian_profile(
c1=dynamic_source_runtime_params.pellet_deposition_location,
c2=dynamic_source_runtime_params.pellet_width,
total=(
dynamic_source_runtime_params.S_pellet_tot
/ dynamic_runtime_params_slice.numerics.nref
),
geo=geo,
)
),)


@dataclasses.dataclass(kw_only=True, frozen=True, eq=True)
Expand Down
5 changes: 3 additions & 2 deletions torax/sources/fusion_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import dataclasses
from typing import ClassVar, Optional

import chex
import jax
from jax import numpy as jnp
from torax import constants
Expand Down Expand Up @@ -149,7 +150,7 @@ def fusion_heat_model_func(
core_profiles: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: Optional['source_models.SourceModels'],
) -> jax.Array:
) -> tuple[chex.Array, ...]:
"""Model function for fusion heating."""
# pytype: enable=name-error
# pylint: disable=invalid-name
Expand All @@ -159,7 +160,7 @@ def fusion_heat_model_func(
static_runtime_params_slice,
dynamic_runtime_params_slice,
)
return jnp.stack((Pfus_i, Pfus_e))
return (Pfus_i, Pfus_e)
# pylint: enable=invalid-name


Expand Down
5 changes: 2 additions & 3 deletions torax/sources/generic_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import ClassVar, Optional

import chex
import jax
from jax import numpy as jnp
from torax import array_typing
from torax import interpolated_param
Expand Down Expand Up @@ -109,7 +108,7 @@ def calculate_generic_current(
unused_state: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: Optional['source_models.SourceModels'] = None,
) -> jax.Array:
) -> tuple[chex.Array, ...]:
"""Calculates the external current density profiles on the cell grid."""
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
source_name
Expand All @@ -135,7 +134,7 @@ def calculate_generic_current(
generic_current_profile = (
Cext * generic_current_form
)
return generic_current_profile
return (generic_current_profile,)


def _calculate_Iext(
Expand Down
8 changes: 3 additions & 5 deletions torax/sources/generic_ion_el_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
from typing import ClassVar, Optional

import chex
import jax
from jax import numpy as jnp
from torax import array_typing
from torax import interpolated_param
from torax import state
Expand Down Expand Up @@ -89,7 +87,7 @@ def calc_generic_heat_source(
w: float,
Ptot: float,
el_heat_fraction: float,
) -> tuple[jax.Array, jax.Array]:
) -> tuple[chex.Array, chex.Array]:
"""Computes ion/electron heat source terms.
Flexible prescribed heat source term.
Expand Down Expand Up @@ -122,7 +120,7 @@ def default_formula(
unused_core_profiles: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: Optional['source_models.SourceModels'],
) -> jax.Array:
) -> tuple[chex.Array, ...]:
"""Returns the default formula-based ion/electron heat source profile."""
# pytype: enable=name-error
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
Expand All @@ -136,7 +134,7 @@ def default_formula(
dynamic_source_runtime_params.Ptot,
dynamic_source_runtime_params.el_heat_fraction,
)
return jnp.stack([ion, el])
return (ion, el)


# pylint: enable=invalid-name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def radially_constant_fraction_of_Pin( # pylint: disable=invalid-name
core_profiles: state.CoreProfiles,
calculated_source_profiles: source_profiles_lib.SourceProfiles | None,
source_models: source_models_lib.SourceModels,
) -> jax.Array:
) -> tuple[chex.Array, ...]:
"""Model function for radiation heat sink from impurities.
This model represents a sink in the temp_el equation, whose value is a fixed %
Expand Down Expand Up @@ -105,8 +105,7 @@ def get_heat_source_profile(source: source_lib.Source) -> jax.Array:
-dynamic_source_runtime_params.fraction_of_total_power_density
* Ptot_in
/ Vtot
* jnp.ones_like(geo.rho)
)
* jnp.ones_like(geo.rho),)


@dataclasses.dataclass(kw_only=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from typing import Final, Mapping, Sequence
import chex
import immutabledict
import jax
import jax.numpy as jnp
import numpy as np
from torax import array_typing
Expand Down Expand Up @@ -193,7 +192,7 @@ def impurity_radiation_mavrin_fit(
core_profiles: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models_lib.SourceModels | None = None,
) -> jax.Array:
) -> tuple[chex.Array, ...]:
"""Model function for impurity radiation heat sink."""
del (geo, unused_source_models)
effective_LZ = calculate_total_impurity_radiation(
Expand All @@ -215,7 +214,7 @@ def impurity_radiation_mavrin_fit(

# The impurity radiation heat sink is a negative source, so we return a
# negative profile.
return -radiation_profile
return (-radiation_profile,)


@dataclasses.dataclass(kw_only=True)
Expand Down
4 changes: 2 additions & 2 deletions torax/sources/ion_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def icrh_model_func(
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
unused_source_models: source_models.SourceModels | None,
toric_nn: ToricNNWrapper,
) -> jax.Array:
) -> tuple[chex.Array, ...]:
"""Compute ion/electron heat source terms."""
dynamic_source_runtime_params = dynamic_runtime_params_slice.sources[
source_name
Expand Down Expand Up @@ -483,7 +483,7 @@ def icrh_model_func(
# Assume that all the power from the tritium power profile goes to ions.
source_ion += power_deposition_2T * dynamic_source_runtime_params.Ptot

return jnp.stack([source_ion, source_el])
return (source_ion, source_el)
# pylint: enable=invalid-name


Expand Down
5 changes: 3 additions & 2 deletions torax/sources/ohmic_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import functools
from typing import ClassVar

import chex
import jax
import jax.numpy as jnp
from torax import constants
Expand Down Expand Up @@ -172,7 +173,7 @@ def ohmic_model_func(
core_profiles: state.CoreProfiles,
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
source_models: source_models_lib.SourceModels,
) -> jax.Array:
) -> tuple[chex.Array, ...]:
"""Returns the Ohmic source for electron heat equation."""
if source_models is None:
raise TypeError('source_models is a required argument for ohmic_model_func')
Expand All @@ -191,7 +192,7 @@ def ohmic_model_func(
)

pohm = jtot * psidot / (2 * jnp.pi * geo.Rmaj)
return pohm
return (pohm,)


@dataclasses.dataclass
Expand Down
4 changes: 2 additions & 2 deletions torax/sources/qei_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,12 @@ def get_value(
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
) -> chex.Array:
) -> tuple[chex.Array, ...]:
raise NotImplementedError('Call get_qei() instead.')

def get_source_profile_for_affected_core_profile(
self,
profile: chex.ArrayTree,
profile: tuple[chex.Array, ...],
affected_mesh_state: int,
geo: geometry.Geometry,
) -> jax.Array:
Expand Down
Loading

0 comments on commit 09933bb

Please sign in to comment.