Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.comp and .loc addresses #285 #288

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 57 additions & 9 deletions jaxley/modules/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,44 @@ def __init__(
# Coordinates.
self.xyzr = [float("NaN") * np.zeros((2, 4))]

def _compartment_view(self) -> CompartmentView:
view = deepcopy(self.nodes)
view["global_comp_index"] = view["comp_index"]
view["global_branch_index"] = view["branch_index"]
view["global_cell_index"] = view["cell_index"]
return CompartmentView(self, view)

@property
def comp(self) -> CompartmentView:
"""Return a compartment of the discretized branch.

Args:
index: integer or float between 0 to 1 for
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it is only 5 people hacking on Jaxley right now I do not think that we have to support legacy code. Please remove the legacy option.

legacy.
"""
return self._compartment_view()

def loc(self, loc: float) -> CompartmentView:
"""Return compartment of the discretized branch that is
closest to the provided location.

Args:
loc: float between 0 and 1
"""
return self.comp.loc(loc)

def __getattr__(self, key):
# Ensure that hidden methods such as `__deepcopy__` still work.
if key.startswith("__"):
return super().__getattribute__(key)

if key == "comp":
view = deepcopy(self.nodes)
view["global_comp_index"] = view["comp_index"]
view["global_branch_index"] = view["branch_index"]
view["global_cell_index"] = view["cell_index"]
return CompartmentView(self, view)
elif key in self.group_nodes:
# if key == "comp":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for cleaning this up! (please remove the comments)

# view = deepcopy(self.nodes)
# view["global_comp_index"] = view["comp_index"]
# view["global_branch_index"] = view["branch_index"]
# view["global_cell_index"] = view["cell_index"]
# return CompartmentView(self, view)
if key in self.group_nodes:
inds = self.group_nodes[key].index.values
view = self.nodes.loc[inds]
view["global_comp_index"] = view["comp_index"]
Expand Down Expand Up @@ -138,6 +164,28 @@ def __call__(self, index: float):
new_view.view["comp_index"] -= new_view.view["comp_index"].iloc[0]
return new_view

def __getattr__(self, key):
assert key == "comp"
def _compartment_view(self) -> CompartmentView:
return CompartmentView(self.pointer, self.view)

@property
def comp(self):
"""Return a compartment of the discretized branch.

Args:
index: integer or float between 0 to 1 for
legacy.
"""
return self._compartment_view()

def loc(self, loc: float):
"""Return compartment of the branch that is
closest to the continuous float location between 0 and 1.

Args:
loc: float between 0 and 1
"""
return self.comp.loc(loc)

# def __getattr__(self, key):
# assert key == "comp"
# return CompartmentView(self.pointer, self.view)
32 changes: 30 additions & 2 deletions jaxley/modules/compartment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, Union
import warnings

import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -63,7 +64,34 @@ def __init__(self, pointer, view):
view = view.assign(controlled_by_param=view.comp_index)
super().__init__(pointer, view)

def __call__(self, loc: float):
def __call__(self, index: Union[float, int]) -> "CompartmentView":
"""Selects a specific compartment with integer indexing from a view onto all compartments.

The resulting object will also be a CompartmentView.
"""
if index == "all":
pass
else:
# support for legacy code
if isinstance(index, float) and 0 <= index <= 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it is only 5 people hacking on Jaxley right now I do not think that we have to support legacy code. Please remove this.

# map the float to an int index
# i.e. the range [0, 1] to [0, N-1] where N is the number of segments
mapped_index = index_of_loc(0, index, self.pointer.nseg)
warnings.warn("Float values for 'index' are deprecated and will be removed in future versions. "
"Use an integer index instead.", DeprecationWarning)
index = mapped_index
elif not isinstance(index, int):
raise ValueError("Index must be an integer or a float between 0 and 1.")
assert (
index >= 0 and index < self.pointer.nseg
), f"Compartments must be indexed by a discrete value between 0 and {self.pointer.nseg - 1}. Provided was {index}."
return super().adjust_view("comp_index", index)

def loc(self, loc: float) -> "CompartmentView":
"""Selects a specific compartment with relative location indexing from a view onto all compartments.

The resulting object will also be a CompartmentView.
"""
if loc != "all":
assert (
loc >= 0.0 and loc <= 1.0
Expand Down
36 changes: 36 additions & 0 deletions tests/test_api_equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import jaxley as jx
from jaxley.synapses import IonotropicSynapse
from jaxley.utils.cell_utils import index_of_loc


def test_api_equivalence_morphology():
Expand Down Expand Up @@ -43,6 +44,41 @@ def test_api_equivalence_morphology():
jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8
), "Voltages do not match between morphology APIs."

def test_api_equivalence_comp_loc():
"""Test the API for comp and loc indexing."""
nseg_per_branch = 10
depth = 2
dt = 0.025

parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = jx.Compartment().initialize()

branch1 = jx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell1 = jx.Cell(
[branch1 for _ in range(num_branches)], parents=parents
).initialize()

branch2 = jx.Branch(comp, nseg=nseg_per_branch).initialize()
cell2 = jx.Cell(branch2, parents=parents).initialize()

loc_record = 0.4
cell1.branch(2).loc(loc_record).record()
cell2.branch(2).comp(int(index_of_loc(0, loc_record, cell2.branch(2).comp.pointer.nseg))).record()

loc_stimulate = 1.0
current = jx.step_current(0.5, 1.0, 1.0, dt, 3.0)
cell1.branch(1).comp(int(index_of_loc(0, loc_stimulate, cell1.branch(1).comp.pointer.nseg))).stimulate(current)
cell2.branch(1).loc(loc_stimulate).stimulate(current)

voltages1 = jx.integrate(cell1, delta_t=dt)
voltages2 = jx.integrate(cell2, delta_t=dt)
assert (
jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8
), "Voltages do not match between morphology APIs."


def test_api_equivalence_synapses():
"""Test whether ways of adding synapses are equivalent."""
Expand Down
Loading