From 3c18ac72ce7afd72b6cea866710134bbc032fd70 Mon Sep 17 00:00:00 2001 From: Mainak Jas Date: Fri, 19 Aug 2022 15:06:30 -0400 Subject: [PATCH] ENH: add option to delete a drive --- examples/howto/plot_connectivity.py | 2 +- hnn_core/network.py | 30 ++++++++++++++++++----------- hnn_core/tests/test_network.py | 2 -- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/examples/howto/plot_connectivity.py b/examples/howto/plot_connectivity.py index eec6e747f..e7cb22051 100644 --- a/examples/howto/plot_connectivity.py +++ b/examples/howto/plot_connectivity.py @@ -71,7 +71,7 @@ # directly. def get_network(probability=1.0): net = jones_2009_model(add_drives_from_params=True) - net.clear_connectivity() + net.connectivity = list() # Pyramidal cell connections location, receptor = 'distal', 'ampa' diff --git a/hnn_core/network.py b/hnn_core/network.py index f5277f954..f66d55663 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -1253,23 +1253,31 @@ def add_connection(self, src_gids, target_gids, loc, receptor, self.connectivity.append(deepcopy(conn)) - def clear_connectivity(self): - """Remove all connections defined in Network.connectivity - """ + def _clear_connectivity(self, src_types=None): + """Remove connections with src_type in Network.connectivity.""" + if src_types is None: + src_types = self.external_drives.keys() connectivity = list() for conn in self.connectivity: - if conn['src_type'] in self.external_drives.keys(): + if conn['src_type'] in src_types: connectivity.append(conn) self.connectivity = connectivity - def clear_drives(self): - """Remove all drives defined in Network.connectivity""" + def clear_drives(self, drive_name='all'): + """Remove all drives defined in Network.connectivity. + + Parameters + ---------- + drive_names : list | 'all' + The drive_names to remove + """ + if drive_names == 'all': + drive_names = list(self.external_drives.keys()) + _validate_type(drive_names, (list,)) connectivity = list() - for conn in self.connectivity: - if conn['src_type'] not in self.external_drives.keys(): - connectivity.append(conn) - self.external_drives = dict() - self.connectivity = connectivity + for drive_name in drive_names: + del self.external_drives[drive_name] + self._clear_connectivity(src_type=drive_names) def add_electrode_array(self, name, electrode_pos, *, conductivity=0.3, method='psa', min_distance=0.5): diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index 249548805..6259a9ec4 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -499,8 +499,6 @@ def test_network(): # Test removing connections from net.connectivity # Needs to be updated if number of drives change in preceeding tests - net.clear_connectivity() - assert len(net.connectivity) == 50 net.clear_drives() assert len(net.connectivity) == 0