Skip to content

Commit

Permalink
Change recommended API for generating branches and cells
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 14, 2023
1 parent 4bda20c commit 17337a3
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 15 deletions.
27 changes: 21 additions & 6 deletions neurax/modules/branch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional, Union

import jax.numpy as jnp
import numpy as np
Expand All @@ -14,14 +14,29 @@ class Branch(Module):
branch_params: Dict = {}
branch_states: Dict = {}

def __init__(self, compartments: List[Compartment]):
def __init__(
self,
compartments: Union[Compartment, List[Compartment]],
nseg: Optional[int] = None,
):
super().__init__()
assert (
isinstance(compartments, Compartment) or nseg is None
), "If `compartments` is a list then you cannot set `nseg`."
assert (
isinstance(compartments, List) or nseg is not None
), "If `compartments` is not a list then you have to set `nseg`."
self._init_params_and_state(self.branch_params, self.branch_states)
self._append_to_params_and_state(compartments)
for comp in compartments:
if isinstance(compartments, Compartment):
compartment_list = [compartments for _ in range(nseg)]
else:
compartment_list = compartments

self._append_to_params_and_state(compartment_list)
for comp in compartment_list:
self._append_to_channel_params_and_state(comp)

self.nseg = len(compartments)
self.nseg = len(compartment_list)
self.total_nbranches = 1
self.nbranches_per_cell = [1]
self.cumsum_nbranches = jnp.asarray([0, 1])
Expand All @@ -36,7 +51,7 @@ def __init__(self, compartments: List[Compartment]):
)

# Channel indexing.
for i, comp in enumerate(compartments):
for i, comp in enumerate(compartment_list):
index = pd.DataFrame.from_dict(
dict(comp_index=[i], branch_index=[0], cell_index=[0])
)
Expand Down
26 changes: 17 additions & 9 deletions neurax/modules/cell.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional, Union

import jax.numpy as jnp
import numpy as np
Expand All @@ -23,18 +23,26 @@ class Cell(Module):
cell_params: Dict = {}
cell_states: Dict = {}

def __init__(self, branches: List[Branch], parents: List):
def __init__(self, branches: Union[Branch, List[Branch]], parents: List):
super().__init__()
assert (
isinstance(branches, Branch) or len(parents) == len(branches) is None
), "If `branches` is a list then you have to provide equally many parents, i.e. len(branches) == len(parents)."
self._init_params_and_state(self.cell_params, self.cell_states)
self._append_to_params_and_state(branches)
for branch in branches:
if isinstance(branches, Branch):
branch_list = [branches for _ in range(len(parents))]
else:
branch_list = branches

self._append_to_params_and_state(branch_list)
for branch in branch_list:
self._append_to_channel_params_and_state(branch)

self.nseg = branches[0].nseg
self.total_nbranches = len(branches)
self.nbranches_per_cell = [len(branches)]
self.nseg = branch_list[0].nseg
self.total_nbranches = len(branch_list)
self.nbranches_per_cell = [len(branch_list)]
self.comb_parents = jnp.asarray(parents)
self.cumsum_nbranches = jnp.asarray([0, len(branches)])
self.cumsum_nbranches = jnp.asarray([0, len(branch_list)])

# Indexing.
self.nodes = pd.DataFrame(
Expand All @@ -48,7 +56,7 @@ def __init__(self, branches: List[Branch], parents: List):
)

# Channel indexing.
for i, branch in enumerate(branches):
for i, branch in enumerate(branch_list):
for channel in branch.channels:
name = type(channel).__name__
comp_inds = deepcopy(
Expand Down

0 comments on commit 17337a3

Please sign in to comment.