Skip to content

Commit

Permalink
rename _to_jax -> to_jax
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 18, 2023
1 parent dab02f0 commit a5feb88
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 19 deletions.
2 changes: 1 addition & 1 deletion jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def integrate(
"""

assert module.initialized, "Module is not initialized, run `.initialize()`."
module._to_jax() # TODO(michaeldeistler): hide.
module.to_jax() # Creates `.jaxnodes` from `.nodes`.

if module.currents is not None:
# At least one stimulus was inserted.
Expand Down
13 changes: 11 additions & 2 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,22 @@ def __str__(self):
return f"jx.{type(self).__name__}"

def _append_params_and_states(self, param_dict, state_dict):
"""Insert the default params of the module (e.g. radius, length).
This is run at `__init__()`. It does not deal with channels.
"""
for param_name, param_value in param_dict.items():
self.nodes[param_name] = param_value
for state_name, state_value in state_dict.items():
self.nodes[state_name] = state_value

def _gather_channels_from_constituents(self, constituents: List) -> None:
"""Modifies `self.channels` and `self.nodes`."""
"""Modify `self.channels` and `self.nodes` with channel info from constituents.
This is run at `__init__()`. It takes all branches of constituents (e.g.
of all branches when the are assembled into a cell) and adds columns to
`.nodes` for the relevant channels.
"""
for module in constituents:
for channel in module.channels:
if type(channel).__name__ not in [
Expand All @@ -106,7 +115,7 @@ def _gather_channels_from_constituents(self, constituents: List) -> None:
name = type(channel).__name__
self.nodes.loc[self.nodes[name].isna(), name] = False

def _to_jax(self):
def to_jax(self):
self.jaxnodes = {}
for key, value in self.nodes.to_dict(orient="list").items():
self.jaxnodes[key] = jnp.asarray(value)
Expand Down
1 change: 0 additions & 1 deletion jaxley/modules/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def __init__(
self.cumsum_nbranches = jnp.asarray([0, 1])

# Indexing.
# TODO: need to take care of setting the `HH` column to False, not NaN.
self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)
self._append_params_and_states(self.branch_params, self.branch_states)
self.nodes["comp_index"] = np.arange(self.nseg).tolist()
Expand Down
2 changes: 1 addition & 1 deletion jaxley/utils/jax_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import math
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar

import pandas as pd
import jax
import jax.numpy as jnp
import pandas as pd

Carry = TypeVar("Carry")
Input = TypeVar("Input")
Expand Down
18 changes: 9 additions & 9 deletions tests/test_make_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np

import jaxley as jx
from jaxley.channels import HH, Na, K
from jaxley.channels import HH, K, Na
from jaxley.synapses import GlutamateSynapse, TestSynapse


Expand Down Expand Up @@ -96,7 +96,7 @@ def test_diverse_synapse_types():
params[0]["gS"] = params[0]["gS"].at[:].set(2.2)
params[1]["gC"] = params[1]["gC"].at[0].set(3.3)
params[1]["gC"] = params[1]["gC"].at[1].set(4.4)
net._to_jax()
net.to_jax()
all_parameters = net.get_all_parameters(params)

assert np.all(all_parameters["radius"] == 1.0)
Expand All @@ -115,7 +115,7 @@ def test_diverse_synapse_types():

# Modify the trainable parameters.
params[2]["gS"] = params[2]["gS"].at[:].set(5.5)
net._to_jax()
net.to_jax()
all_parameters = net.get_all_parameters(params)
assert np.all(all_parameters["gS"][0] == 2.2)
assert np.all(all_parameters["gS"][1] == 5.5)
Expand Down Expand Up @@ -202,29 +202,29 @@ def get_params_subset_trainable(net):
net.cell(0).branch(1).make_trainable("HH_gNa")
params = net.get_parameters()
params[0]["HH_gNa"] = params[0]["HH_gNa"].at[:].set(0.0)
net._to_jax()
net.to_jax()
return net.get_all_parameters(trainable_params=params)


def get_params_set_subset(net):
net.cell(0).branch(1).set("HH_gNa", 0.0)
params = net.get_parameters()
net._to_jax()
net.to_jax()
return net.get_all_parameters(trainable_params=params)


def get_params_all_trainable(net):
net.cell("all").branch("all").comp("all").make_trainable("HH_gNa")
params = net.get_parameters()
params[0]["HH_gNa"] = params[0]["HH_gNa"].at[:].set(0.0)
net._to_jax()
net.to_jax()
return net.get_all_parameters(trainable_params=params)


def get_params_set(net):
net.set("HH_gNa", 0.0)
params = net.get_parameters()
net._to_jax()
net.to_jax()
return net.get_all_parameters(trainable_params=params)


Expand All @@ -236,15 +236,15 @@ def test_make_trainable_corresponds_to_set_pospischil():
net1.cell("all").branch("all").comp("all").make_trainable("vt")
params = net1.get_parameters()
params[0]["vt"] = params[0]["vt"].at[:].set(0.05)
net1._to_jax()
net1.to_jax()
params1 = net1.get_all_parameters(trainable_params=params)

net2.cell(0).insert(Na())
net2.insert(K())
net2.cell("all").branch("all").comp("all").make_trainable("vt")
params = net2.get_parameters()
params[0]["vt"] = params[0]["vt"].at[:].set(0.05)
net2._to_jax()
net2.to_jax()
params2 = net2.get_all_parameters(trainable_params=params)
assert np.array_equal(params1["vt"], params2["vt"], equal_nan=True)
assert np.array_equal(params1["Na_gNa"], params2["Na_gNa"], equal_nan=True)
Expand Down
6 changes: 1 addition & 5 deletions tests/test_shared_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
import numpy as np

import jaxley as jx
from jaxley.channels import HH, Na, K


import jax.numpy as jnp
from jaxley.channels import Channel
from jaxley.channels import HH, Channel, K, Na


class Dummy1(Channel):
Expand Down

0 comments on commit a5feb88

Please sign in to comment.