Skip to content

Commit

Permalink
Begin to simplify some of the core_profile_setters code
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723939498
  • Loading branch information
tamaranorman authored and Torax team committed Feb 6, 2025
1 parent 7afe63f commit 7f68a96
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 128 deletions.
144 changes: 23 additions & 121 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,132 +282,19 @@ def get_ion_density_and_charge_states(
return ni, nimp, Zi, Zi_face, Zimp, Zimp_face


def _prescribe_currents_no_bootstrap(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
source_models: source_models_lib.SourceModels,
) -> state.Currents:
"""Creates the initial Currents without the bootstrap current.
Args:
static_runtime_params_slice: Static runtime parameters.
dynamic_runtime_params_slice: General runtime parameters at t_initial.
geo: Geometry of the tokamak.
core_profiles: Core profiles.
source_models: All TORAX source/sink functions.
Returns:
currents: Initial Currents
"""
# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style

# Calculate splitting of currents depending on input runtime params.
Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot

# Set zero bootstrap current
bootstrap_profile = source_profiles_lib.BootstrapCurrentProfile.zero_profile(
geo
)

# calculate "External" current profile (e.g. ECCD)
external_current = source_models.external_current_source(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
)
Iext = (
math_utils.cell_integration(external_current * geo.spr, geo) / 10**6
)
# Total Ohmic current.
Iohm = Ip - Iext

# construct prescribed current formula on grid.
jformula = (
1 - geo.rho_norm**2
) ** dynamic_runtime_params_slice.profile_conditions.nu
# calculate total and Ohmic current profiles
denom = _trapz(jformula * geo.spr, geo.rho_norm)
if dynamic_runtime_params_slice.profile_conditions.initial_j_is_total_current:
Ctot = Ip * 1e6 / denom
jtot = jformula * Ctot
johm = jtot - external_current
else:
Cohm = Iohm * 1e6 / denom
johm = jformula * Cohm
jtot = johm + external_current

jtot_hires = _get_jtot_hires(
dynamic_runtime_params_slice,
geo,
bootstrap_profile,
Iohm,
external_current=external_current,
)

currents = state.Currents(
jtot=jtot,
jtot_face=math_utils.cell_to_face(
jtot,
geo,
preserved_quantity=math_utils.IntegralPreservationQuantity.SURFACE,
),
jtot_hires=jtot_hires,
johm=johm,
external_current_source=external_current,
j_bootstrap=bootstrap_profile.j_bootstrap,
j_bootstrap_face=bootstrap_profile.j_bootstrap_face,
I_bootstrap=bootstrap_profile.I_bootstrap,
Ip_profile_face=jnp.zeros(geo.rho_face.shape), # psi not yet calculated
sigma=bootstrap_profile.sigma,
)

return currents


def _prescribe_currents_with_bootstrap(
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
def _prescribe_currents(
bootstrap_profile: source_profiles_lib.BootstrapCurrentProfile,
external_current: jax.Array,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
source_models: source_models_lib.SourceModels,
) -> state.Currents:
"""Creates the initial Currents.
Args:
static_runtime_params_slice: Static runtime parameters.
dynamic_runtime_params_slice: General runtime parameters at t_initial.
geo: Geometry of the tokamak.
core_profiles: Core profiles.
source_models: All TORAX source/sink functions. If not provided, uses the
default sources.
Returns:
currents: Plasma currents
"""
"""Creates the initial Currents from a given bootstrap profile."""

# Many variables throughout this function are capitalized based on physics
# notational conventions rather than on Google Python style
Ip = dynamic_runtime_params_slice.profile_conditions.Ip_tot

bootstrap_profile = source_models.j_bootstrap.get_bootstrap(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
)
f_bootstrap = bootstrap_profile.I_bootstrap / (Ip * 1e6)

# calculate "External" current profile (e.g. ECCD)
external_current = source_models.external_current_source(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
)
Iext = (
math_utils.cell_integration(external_current * geo.spr, geo) / 10**6
)
Expand Down Expand Up @@ -653,12 +540,21 @@ def _init_psi_and_current(
isinstance(geo, circular_geometry.CircularAnalyticalGeometry)
or dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j
):
currents = _prescribe_currents_no_bootstrap(
# First calculate currents without bootstrap.
bootstrap = source_profiles_lib.BootstrapCurrentProfile.zero_profile(
geo
)
external_current = source_models.external_current_source(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
source_models=source_models,
)
currents = _prescribe_currents(
bootstrap_profile=bootstrap,
external_current=external_current,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
)
psi = _update_psi_from_j(
dynamic_runtime_params_slice,
Expand All @@ -668,12 +564,18 @@ def _init_psi_and_current(
core_profiles = dataclasses.replace(
core_profiles, currents=currents, psi=psi
)
currents = _prescribe_currents_with_bootstrap(
# Now calculate currents with bootstrap.
bootstrap_profile = source_models.j_bootstrap.get_bootstrap(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=core_profiles,
source_models=source_models,
)
currents = _prescribe_currents(
bootstrap_profile=bootstrap_profile,
external_current=external_current,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
)
psi = _update_psi_from_j(
dynamic_runtime_params_slice,
Expand Down
17 changes: 10 additions & 7 deletions torax/tests/physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from torax.sources import generic_current_source
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_models as source_models_lib
from torax.sources import source_profiles
from torax.tests.test_lib import torax_refs


Expand Down Expand Up @@ -137,12 +138,15 @@ def test_update_psi_from_j(

# pylint: disable=protected-access
if isinstance(geo, circular_geometry.CircularAnalyticalGeometry):
currents = core_profile_setters._prescribe_currents_no_bootstrap(
static_slice,
dynamic_runtime_params_slice,
geo,
source_models=source_models,
core_profiles=initial_core_profiles,
bootstrap = source_profiles.BootstrapCurrentProfile.zero_profile(geo)
external_current = source_models.external_current_source(
geo, initial_core_profiles, dynamic_runtime_params_slice, static_slice
)
currents = core_profile_setters._prescribe_currents(
bootstrap_profile=bootstrap,
external_current=external_current,
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
geo=geo,
)
psi = core_profile_setters._update_psi_from_j(
dynamic_runtime_params_slice, geo, currents.jtot_hires
Expand All @@ -152,7 +156,6 @@ def test_update_psi_from_j(
else:
raise ValueError(f'Unknown geometry type: {geo.geometry_type}')
# pylint: enable=protected-access
print(psi)

np.testing.assert_allclose(psi, references.psi.value)

Expand Down

0 comments on commit 7f68a96

Please sign in to comment.