Skip to content

Commit

Permalink
add uninsert method (#521)
Browse files Browse the repository at this point in the history
* add: add uninsert method

* add: add tests

* fix: rename uninsert -> delete_channel
  • Loading branch information
jnsbck authored Nov 21, 2024
1 parent 5e3637c commit 8ce2428
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
22 changes: 22 additions & 0 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1735,6 +1735,28 @@ def insert(self, channel: Channel):
for key in channel.channel_states:
self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key]

def delete_channel(self, channel: Channel):
"""Remove a channel from the module.
Args:
channel: The channel to remove."""
name = channel._name
channel_names = [c._name for c in self.channels]
all_channel_names = [c._name for c in self.base.channels]
if name in channel_names:
channel_cols = list(channel.channel_params.keys())
channel_cols += list(channel.channel_states.keys())
self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan")
self.base.nodes.loc[self._nodes_in_view, name] = False

# only delete cols if no other comps in the module have the same channel
if np.all(~self.base.nodes[name]):
self.base.channels.pop(all_channel_names.index(name))
self.base.membrane_current_names.remove(channel.current_name)
self.base.nodes.drop(columns=channel_cols + [name], inplace=True)
else:
raise ValueError(f"Channel {name} not found in the module.")

@only_allow_module
def step(
self,
Expand Down
60 changes: 60 additions & 0 deletions tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,63 @@ def compute_current(self, states, v, params):
num_channels = 2
target = (t_max // dt + 2) * 0.001 * 0.01 * num_channels
assert np.abs(target - s[0, -1]) < 1e-8


def test_delete_channel(SimpleBranch):
# test complete removal of a channel from a module
branch1 = SimpleBranch(nseg=3)
branch1.comp(0).insert(K())
branch1.delete_channel(K())

branch2 = SimpleBranch(nseg=3)
branch2.comp(0).insert(K())
branch2.comp(0).delete_channel(K())

branch3 = SimpleBranch(nseg=3)
branch3.insert(K())
branch3.delete_channel(K())

def channel_present(view, channel, partial=False):
states_and_params = list(channel.channel_states.keys()) + list(
channel.channel_params.keys()
)
# none of the states or params should be in nodes
cols = view.nodes.columns.to_list()
channel_cols = [
col
for col in cols
if col.startswith(channel._name) and col != channel._name
]
diff = set(channel_cols).difference(set(states_and_params))
has_params_or_states = len(diff) > 0
has_channel_col = channel._name in view.nodes.columns
has_channel = channel._name in [c._name for c in view.channels]
has_mem_current = channel.current_name in view.membrane_current_names
if partial:
all_nans = (
not view.nodes[channel_cols].isna().all().all()
& ~view.nodes[channel._name].all()
)
return has_channel or has_mem_current or all_nans
return has_params_or_states or has_channel_col or has_channel or has_mem_current

for branch in [branch1, branch2, branch3]:
assert len(branch.channels) == 0
assert not channel_present(branch, K())

# test correct channels are removed only in the viewed part of the module
branch4 = SimpleBranch(nseg=3)
branch4.insert(HH())
branch4.comp(0).insert(K())
branch4.comp([1, 2]).insert(Leak())

branch4.comp(1).delete_channel(Leak())
# assert K in comp 0 and Leak still present in branch
assert channel_present(branch4.comp(0), K())
assert channel_present(branch4.comp(2), Leak(), partial=True)
assert not channel_present(branch4.comp(1), Leak(), partial=True)
assert channel_present(branch4, Leak())

branch4.comp(2).delete_channel(Leak())
# assert no more Leak
assert not channel_present(branch4, Leak())

0 comments on commit 8ce2428

Please sign in to comment.