17
18
19
20
@@ -1890,14 +1889,7 @@
54
55
56
-57
-58
-59
-60
-61
-62
-63
-64 | def step_voltage_explicit(
+57
| def step_voltage_explicit(
voltages: jnp.ndarray,
voltage_terms: jnp.ndarray,
constant_terms: jnp.ndarray,
@@ -1906,18 +1898,13 @@
sinks: jnp.ndarray,
sources: jnp.ndarray,
types: jnp.ndarray,
- masked_node_inds: jnp.ndarray,
nseg_per_branch: jnp.ndarray,
- nseg: int,
par_inds: jnp.ndarray,
child_inds: jnp.ndarray,
nbranches: int,
solver: str,
delta_t: float,
- children_in_level: List[jnp.ndarray],
- parents_in_level: List[jnp.ndarray],
- root_inds: jnp.ndarray,
- branchpoint_group_inds: jnp.ndarray,
+ idx: JaxleySolveIndexer,
debug_states,
) -> jnp.ndarray:
"""Solve one timestep of branched nerve equations with explicit (forward) Euler."""
@@ -1938,10 +1925,7 @@
nbranches,
solver,
delta_t,
- children_in_level,
- parents_in_level,
- root_inds,
- branchpoint_group_inds,
+ idx,
debug_states,
)
new_voltates = voltages + delta_t * update
@@ -1956,7 +1940,7 @@
- step_voltage_implicit_with_jaxley_spsolve(voltages, voltage_terms, constant_terms, axial_conductances, internal_node_inds, sinks, sources, types, masked_node_inds, nseg_per_branch, nseg, par_inds, child_inds, nbranches, solver, delta_t, children_in_level, parents_in_level, root_inds, branchpoint_group_inds, debug_states)
+ step_voltage_implicit_with_jaxley_spsolve(voltages, voltage_terms, constant_terms, axial_conductances, internal_node_inds, sinks, sources, types, nseg_per_branch, par_inds, child_inds, nbranches, solver, delta_t, idx, debug_states)
@@ -1967,7 +1951,14 @@
Source code in jaxley/solver_voltage.py
- 67
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
68
69
70
@@ -2111,33 +2102,7 @@ 208
209
210
-211
-212
-213
-214
-215
-216
-217
-218
-219
-220
-221
-222
-223
-224
-225
-226
-227
-228
-229
-230
-231
-232
-233
-234
-235
-236
-237 | def step_voltage_implicit_with_jaxley_spsolve(
+211
| def step_voltage_implicit_with_jaxley_spsolve(
voltages: jnp.ndarray,
voltage_terms: jnp.ndarray,
constant_terms: jnp.ndarray,
@@ -2146,37 +2111,31 @@ sinks: jnp.ndarray,
sources: jnp.ndarray,
types: jnp.ndarray,
- masked_node_inds: jnp.ndarray,
nseg_per_branch: jnp.ndarray,
- nseg: int,
par_inds: jnp.ndarray,
child_inds: jnp.ndarray,
nbranches: int,
solver: str,
delta_t: float,
- children_in_level: List[jnp.ndarray],
- parents_in_level: List[jnp.ndarray],
- root_inds: jnp.ndarray,
- branchpoint_group_inds: jnp.ndarray,
+ idx: JaxleySolveIndexer,
debug_states,
):
"""Solve one timestep of branched nerve equations with implicit (backward) Euler."""
# Build diagonals.
c2c = np.isin(types, [0, 1, 2])
- diags = jnp.ones(nbranches * nseg)
+ total_ncomp = idx.cumsum_nseg[-1]
+ diags = jnp.ones(total_ncomp)
# if-case needed because `.at` does not allow empty inputs, but the input is
# empty for compartments.
if len(sinks[c2c]) > 0:
- diags = diags.at[masked_node_inds[sinks[c2c]]].add(
- delta_t * axial_conductances[c2c]
- )
+ diags = diags.at[idx.mask(sinks[c2c])].add(delta_t * axial_conductances[c2c])
- diags = diags.at[masked_node_inds[internal_node_inds]].add(delta_t * voltage_terms)
+ diags = diags.at[idx.mask(internal_node_inds)].add(delta_t * voltage_terms)
# Build solves.
- solves = jnp.zeros(nbranches * nseg)
- solves = solves.at[masked_node_inds[internal_node_inds]].add(
+ solves = jnp.zeros(total_ncomp)
+ solves = solves.at[idx.mask(internal_node_inds)].add(
voltages + delta_t * constant_terms
)
@@ -2184,32 +2143,23 @@ c2c = types == 0 # c2c = compartment-to-compartment.
# Build uppers.
- uppers = jnp.zeros(nbranches * nseg)
+ uppers = jnp.zeros(total_ncomp)
upper_inds = sources[c2c] > sinks[c2c]
sinks_upper = sinks[c2c][upper_inds]
if len(sinks_upper) > 0:
- uppers = uppers.at[masked_node_inds[sinks_upper]].add(
+ uppers = uppers.at[idx.mask(sinks_upper)].add(
-delta_t * axial_conductances[c2c][upper_inds]
)
# Build lowers.
- lowers = jnp.zeros(nbranches * nseg)
+ lowers = jnp.zeros(total_ncomp)
lower_inds = sources[c2c] < sinks[c2c]
sinks_lower = sinks[c2c][lower_inds]
if len(sinks_lower) > 0:
- lowers = lowers.at[masked_node_inds[sinks_lower]].add(
+ lowers = lowers.at[idx.mask(sinks_lower)].add(
-delta_t * axial_conductances[c2c][lower_inds]
)
- # Reshape all diags, lowers, uppers, and solves into a "per-branch" format.
- diags = jnp.reshape(diags, (nbranches, -1))
- solves = jnp.reshape(solves, (nbranches, -1))
- uppers = jnp.reshape(uppers, (nbranches, -1))
- lowers = jnp.reshape(lowers, (nbranches, -1))
- # lowers and uppers were built to have length `nseg` above for simplicity.
- uppers = uppers[:, :-1]
- lowers = lowers[:, 1:]
-
# Build branchpoint conductances.
branchpoint_conds_parents = axial_conductances[types == 1]
branchpoint_conds_children = axial_conductances[types == 2]
@@ -2221,7 +2171,7 @@ # Find unique group identifiers
num_branchpoints = len(branchpoint_conds_parents)
branchpoint_diags = -group_and_sum(
- all_branchpoint_vals, branchpoint_group_inds, num_branchpoints
+ all_branchpoint_vals, idx.branchpoint_group_inds, num_branchpoints
)
branchpoint_solves = jnp.zeros((num_branchpoints,))
@@ -2274,10 +2224,8 @@ branchpoint_diags,
branchpoint_solves,
solver,
- children_in_level,
- parents_in_level,
- root_inds,
nseg_per_branch,
+ idx,
debug_states,
)
@@ -2301,13 +2249,11 @@ branchpoint_diags,
branchpoint_solves,
solver,
- children_in_level,
- parents_in_level,
- root_inds,
nseg_per_branch,
+ idx,
debug_states,
)
- return solves.ravel(order="C")[masked_node_inds[internal_node_inds]]
+ return solves.ravel(order="C")[idx.mask(internal_node_inds)]
|
diff --git a/dev/reference/modules/index.html b/dev/reference/modules/index.html
index 21b2b610..9c941f53 100644
--- a/dev/reference/modules/index.html
+++ b/dev/reference/modules/index.html
@@ -3455,13 +3455,7 @@ Module | | class Module(ABC):
"""Module base class.
Modules are everything that can be passed to `jx.integrate`, i.e. compartments,
@@ -3500,7 +3494,6 @@ Module
|
|