From 68dba2e021dcc6d2f9ae473deb809a0574336e25 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 6 Dec 2023 16:38:55 +0100 Subject: [PATCH] introduce option to pass xyzr --- jaxley/modules/cell.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index c8832760..e10374a6 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -23,7 +23,19 @@ class Cell(Module): cell_params: Dict = {} cell_states: Dict = {} - def __init__(self, branches: Union[Branch, List[Branch]], parents: List): + def __init__( + self, + branches: Union[Branch, List[Branch]], + parents: List, + xyzr: Optional[Union[np.ndarray, jnp.ndarray]], + ): + """Initialize a cell. + + Args: + branches: + parents: + xyzr: The x, y, and z coordinates and the radius at these coordinates. + """ super().__init__() assert isinstance(branches, Branch) or len(parents) == len( branches @@ -249,7 +261,7 @@ def read_swc( comp = Compartment().initialize() branch = Branch([comp for _ in range(nseg)]).initialize() - cell = Cell([branch for _ in range(nbranches)], parents=parents) + cell = Cell([branch for _ in range(nbranches)], parents=parents, xyzr=None) radiuses = np.flip( np.asarray([radius_fns[b](range_) for b in range(len(parents))]), axis=1