From 490ac1f2f60126ed2053dce8859fae22308dae66 Mon Sep 17 00:00:00 2001 From: Mainak Jas Date: Wed, 23 Jun 2021 20:56:41 -0400 Subject: [PATCH 1/6] ENH: interactive connectivity --- hnn_core/viz.py | 103 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 73 insertions(+), 30 deletions(-) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index a875cb754..f2ee4d3ba 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -836,7 +836,41 @@ def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True, return ax.get_figure() -def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True, +def _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos, + src_range, target_range, nc_dict, colormap): + from .cell import _get_gaussian_connection + + # Extract indeces to get position in network + # Index in gid range aligns with net.pos_dict + target_src_pair = conn['gid_pairs'][src_gid] + target_indeces = np.where(np.in1d(target_range, target_src_pair))[0] + + src_idx = np.where(src_range == src_gid)[0][0] + src_pos = src_type_pos[src_idx] + + # Aggregate positions and weight of each connected target + weights, target_x_pos, target_y_pos = list(), list(), list() + for target_idx in target_indeces: + target_pos = target_type_pos[target_idx] + target_x_pos.append(target_pos[0]) + target_y_pos.append(target_pos[1]) + weight, _ = _get_gaussian_connection(src_pos, target_pos, nc_dict) + weights.append(weight) + + ax.clear() + im = ax.scatter(target_x_pos, target_y_pos, c=weights, s=50, + cmap=colormap) + x_pos = [target_type_pos[idx][0] for idx in range(len(target_type_pos))] + y_pos = [target_type_pos[idx][1] for idx in range(len(target_type_pos))] + + ax.scatter(x_pos, y_pos, color='k', marker='x', zorder=-1, s=20) + ax.scatter(src_pos[0], src_pos[1], marker='s', color='red', s=150) + ax.set_ylabel('Y Position') + ax.set_xlabel('X Position') + return im + + +def plot_cell_connectivity(net, conn_idx, src_gid, axes=None, colorbar=True, colormap='viridis', show=True): """Plot synaptic weight of connections originating from src_gid. @@ -849,7 +883,7 @@ def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True, from `net.connectivity` src_gid : int Each cell in a network is uniquely identified by it's "global ID": GID. - ax : instance of Axes3D + axes : instance of Axes3D Matplotlib 3D axis colormap : str The name of a matplotlib colormap. Default: 'viridis' @@ -881,8 +915,6 @@ def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True, _validate_type(net, Network, 'net', 'Network') _validate_type(conn_idx, int, 'conn_idx', 'int') _validate_type(src_gid, int, 'src_gid', 'int') - if ax is None: - _, ax = plt.subplots(1, 1) # Load objects for distance calculation conn = net.connectivity[conn_idx] @@ -900,38 +932,46 @@ def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True, target_range = np.array(conn['target_range']) - # Extract indeces to get position in network - # Index in gid range aligns with net.pos_dict - target_src_pair = conn['gid_pairs'][src_gid] - target_indeces = np.where(np.in1d(target_range, target_src_pair))[0] - - src_idx = np.where(src_range == src_gid)[0][0] - src_pos = src_type_pos[src_idx] - - # Aggregate positions and weight of each connected target - weights, target_x_pos, target_y_pos = list(), list(), list() - for target_idx in target_indeces: - target_pos = target_type_pos[target_idx] - target_x_pos.append(target_pos[0]) - target_y_pos.append(target_pos[1]) - weight, _ = _get_gaussian_connection(src_pos, target_pos, nc_dict) - weights.append(weight) + if axes is None: + if src_type in net.cell_types: + fig, axes = plt.subplots(1, 2, sharex=True, sharey=True) + else: + fig, axes = plt.subplots(1, 1, sharex=True, sharey=True) + axes = [axes] - im = ax.scatter(target_x_pos, target_y_pos, c=weights, s=50, cmap=colormap) + if len(axes) == 2: + ax_src, ax = axes + else: + ax = axes[0] - # Gather positions of all gids in target_type. - x_pos = [target_type_pos[idx][0] for idx in range(len(target_type_pos))] - y_pos = [target_type_pos[idx][1] for idx in range(len(target_type_pos))] - ax.scatter(x_pos, y_pos, color='k', marker='x', zorder=-1, s=20) + im = _update_target_plot(ax, conn, src_gid, src_type_pos, + target_type_pos, src_range, + target_range, nc_dict, colormap) - # Only plot src_gid if proper cell type. + x_src = [src_pos[0] for src_pos in src_type_pos] + y_src = [src_pos[1] for src_pos in src_type_pos] if src_type in net.cell_types: - ax.scatter(src_pos[0], src_pos[1], marker='s', color='red', s=150) - ax.set_ylabel('Y Position') - ax.set_xlabel('X Position') - ax.set_title(f"{conn['src_type']}-> {conn['target_type']}" + ax_src.scatter(x_src, y_src, marker='s', color='red', s=50) + + plt.suptitle(f"{conn['src_type']}-> {conn['target_type']}" f" ({conn['loc']}, {conn['receptor']})") + def _onclick(event): + if event.inaxes in [ax]: + return + + dist = np.linalg.norm(np.array(src_type_pos)[:, :2] - + np.array([event.xdata, event.ydata]), + axis=1) + src_idx = np.argmin(dist) + + src_gid = src_range[src_idx] + _update_target_plot(ax, conn, src_gid, src_type_pos, + target_type_pos, src_range, target_range, + nc_dict, colormap) + + fig.canvas.draw() + if colorbar: fig = ax.get_figure() xfmt = ScalarFormatter() @@ -941,5 +981,8 @@ def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True, cbar.ax.set_ylabel('Weight', rotation=-90, va="bottom") plt.tight_layout() + + fig.canvas.mpl_connect('button_press_event', _onclick) + plt_show(show) return ax.get_figure(), ax From 39ad1bce6fbeb8bd65f26366e84a172e40d5251b Mon Sep 17 00:00:00 2001 From: Mainak Jas Date: Fri, 2 Jul 2021 10:13:52 -0400 Subject: [PATCH 2/6] FIX: not all gids in src_ranges may be valid --- hnn_core/viz.py | 51 +++++++++++++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index f2ee4d3ba..1bfbdb00e 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -860,9 +860,8 @@ def _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos, ax.clear() im = ax.scatter(target_x_pos, target_y_pos, c=weights, s=50, cmap=colormap) - x_pos = [target_type_pos[idx][0] for idx in range(len(target_type_pos))] - y_pos = [target_type_pos[idx][1] for idx in range(len(target_type_pos))] - + x_pos = target_type_pos[:, 0] + y_pos = target_type_pos[:, 1] ax.scatter(x_pos, y_pos, color='k', marker='x', zorder=-1, s=20) ax.scatter(src_pos[0], src_pos[1], marker='s', color='red', s=150) ax.set_ylabel('Y Position') @@ -870,8 +869,8 @@ def _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos, return im -def plot_cell_connectivity(net, conn_idx, src_gid, axes=None, colorbar=True, - colormap='viridis', show=True): +def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, + colorbar=True, colormap='viridis', show=True): """Plot synaptic weight of connections originating from src_gid. Parameters @@ -881,8 +880,10 @@ def plot_cell_connectivity(net, conn_idx, src_gid, axes=None, colorbar=True, conn_idx : int Index of connection to be visualized from `net.connectivity` - src_gid : int - Each cell in a network is uniquely identified by it's "global ID": GID. + src_gid : int | None + The cell ID of the source cell. It must be an element of + net.connectivity[conn_idx]['gid_pairs'].keys() + If None, the first src_gid from the list of valid src_gids is selected. axes : instance of Axes3D Matplotlib 3D axis colormap : str @@ -914,21 +915,26 @@ def plot_cell_connectivity(net, conn_idx, src_gid, axes=None, colorbar=True, _validate_type(net, Network, 'net', 'Network') _validate_type(conn_idx, int, 'conn_idx', 'int') - _validate_type(src_gid, int, 'src_gid', 'int') # Load objects for distance calculation conn = net.connectivity[conn_idx] nc_dict = conn['nc_dict'] src_type = conn['src_type'] target_type = conn['target_type'] - src_type_pos = net.pos_dict[src_type] - target_type_pos = net.pos_dict[target_type] - + src_type_pos = np.array(net.pos_dict[src_type]) + target_type_pos = np.array(net.pos_dict[target_type]) src_range = np.array(conn['src_range']) - if src_gid not in src_range: - raise ValueError(f'src_gid not in the src type range of {src_type} ' - f'gids. Valid gids include {src_range[0]} -> ' - f'{src_range[-1]}') + + valid_src_gids = list(net.connectivity[conn_idx]['gid_pairs'].keys()) + src_pos_valid = src_type_pos[np.in1d(src_range, valid_src_gids)] + + if src_gid is None: + src_gid = valid_src_gids[0] + _validate_type(src_gid, int, 'src_gid', 'int') + + if src_gid not in valid_src_gids: + raise ValueError(f'src_gid not a valid cell ID for this connection ' + f'Please select one of {valid_src_gids}') target_range = np.array(conn['target_range']) @@ -948,10 +954,15 @@ def plot_cell_connectivity(net, conn_idx, src_gid, axes=None, colorbar=True, target_type_pos, src_range, target_range, nc_dict, colormap) - x_src = [src_pos[0] for src_pos in src_type_pos] - y_src = [src_pos[1] for src_pos in src_type_pos] + x_src = src_type_pos[:, 0] + y_src = src_type_pos[:, 1] + x_src_valid = src_pos_valid[:, 0] + y_src_valid = src_pos_valid[:, 1] if src_type in net.cell_types: - ax_src.scatter(x_src, y_src, marker='s', color='red', s=50) + ax_src.scatter(x_src, y_src, marker='s', color='red', s=50, + alpha=0.2) + ax_src.scatter(x_src_valid, y_src_valid, marker='s', color='red', + s=50) plt.suptitle(f"{conn['src_type']}-> {conn['target_type']}" f" ({conn['loc']}, {conn['receptor']})") @@ -960,12 +971,14 @@ def _onclick(event): if event.inaxes in [ax]: return - dist = np.linalg.norm(np.array(src_type_pos)[:, :2] - + dist = np.linalg.norm(src_type_pos[:, :2] - np.array([event.xdata, event.ydata]), axis=1) src_idx = np.argmin(dist) src_gid = src_range[src_idx] + if src_gid not in valid_src_gids: + return _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos, src_range, target_range, nc_dict, colormap) From d3c7c928b2b3c1a836441c74332ec79addd68584 Mon Sep 17 00:00:00 2001 From: Mainak Jas Date: Fri, 2 Jul 2021 12:03:42 -0400 Subject: [PATCH 3/6] Flake8 --- hnn_core/viz.py | 1 - 1 file changed, 1 deletion(-) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 1bfbdb00e..76a95bb28 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -910,7 +910,6 @@ def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, """ import matplotlib.pyplot as plt from .network import Network - from .cell import _get_gaussian_connection from matplotlib.ticker import ScalarFormatter _validate_type(net, Network, 'net', 'Network') From 9380d5cc1804ec37fcb10661bbafed4afc5756aa Mon Sep 17 00:00:00 2001 From: Mainak Jas Date: Fri, 2 Jul 2021 12:37:02 -0400 Subject: [PATCH 4/6] TST with fake clicks --- examples/howto/plot_connectivity.py | 2 +- hnn_core/tests/test_viz.py | 22 +++++++++++++++++++++- hnn_core/viz.py | 4 ++-- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/examples/howto/plot_connectivity.py b/examples/howto/plot_connectivity.py index 4584628b9..e2bd906d1 100644 --- a/examples/howto/plot_connectivity.py +++ b/examples/howto/plot_connectivity.py @@ -51,7 +51,7 @@ gid_idx = 11 src_gid = net_erp.connectivity[conn_idx]['src_range'][gid_idx] -fig, ax = plot_cell_connectivity(net_erp, conn_idx, src_gid) +fig = plot_cell_connectivity(net_erp, conn_idx, src_gid) ############################################################################### # Data recorded during simulations are stored under diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index 158531100..03d7e3678 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -1,3 +1,4 @@ +from functools import partial import os.path as op import matplotlib @@ -13,6 +14,13 @@ matplotlib.use('agg') +def _fake_click(fig, ax, point, button=1): + """Fake a click at a point within axes.""" + x, y = ax.transData.transform_point(point) + func = partial(fig.canvas.button_press_event, x=x, y=y, button=button) + func(guiEvent=None) + + def test_network_visualization(): """Test network visualisations.""" hnn_core_root = op.dirname(hnn_core.__file__) @@ -47,9 +55,21 @@ def test_network_visualization(): with pytest.raises(TypeError, match='src_gid must be an instance of'): plot_cell_connectivity(net, conn_idx, src_gid='blah') - with pytest.raises(ValueError, match='src_gid not in the'): + with pytest.raises(ValueError, match='src_gid not a valid cell ID'): plot_cell_connectivity(net, conn_idx, src_gid=-1) + # smoke test interactive clicking + del net.connectivity[-1] + conn_idx = 15 + net.add_connection(net.gid_ranges['L2_pyramidal'][::2], + 'L5_basket', 'soma', + 'ampa', 0.00025, 1.0, lamtha=3.0, + probability=0.8) + fig = plot_cell_connectivity(net, conn_idx) + ax_src = fig.axes[0] + pos = net.pos_dict['L2_pyramidal'][2] + _fake_click(fig, ax_src, [pos[0], pos[1]]) + def test_dipole_visualization(): """Test dipole visualisations.""" diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 76a95bb28..e8335a95c 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -967,7 +967,7 @@ def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, f" ({conn['loc']}, {conn['receptor']})") def _onclick(event): - if event.inaxes in [ax]: + if event.inaxes in [ax] or event.inaxes is None: return dist = np.linalg.norm(src_type_pos[:, :2] - @@ -997,4 +997,4 @@ def _onclick(event): fig.canvas.mpl_connect('button_press_event', _onclick) plt_show(show) - return ax.get_figure(), ax + return ax.get_figure() From da2b9e73f31638ce3f8d996611a530fa81e16c98 Mon Sep 17 00:00:00 2001 From: Mainak Jas Date: Fri, 2 Jul 2021 14:56:12 -0400 Subject: [PATCH 5/6] DOC: update whats new --- doc/whats_new.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 60361c7a1..2d7ed0d41 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -18,7 +18,7 @@ Changelog - Add new function :func:`~hnn_core.viz.plot_cell_morphology` to visualize cell morphology, by `Mainak Jas`_ in `#319 `_ -- Compute dipole component in z-direction automatically from cell morphology instead of hard coding, by `Mainak Jas`_ in `#327 `_ +- Compute dipole component in z-direction automatically from cell morphology instead of hard coding, by `Mainak Jas`_ in `#327 `_ - Store :class:`~hnn_core.Cell` instances in :class:`~hnn_core.Network`'s :attr:`~/hnn_core.Network.cells` attribute by `Ryan Thorpe`_ in `#321 `_ @@ -32,6 +32,8 @@ Changelog - Previously published models can now be loaded via ``net=law_2021_model()`` and ``jones_2009_model()``, by `Nick Tolley`_ in `#348 `_ +- Add ability to interactivity explore connections in :func:`~hnn_core.viz.plot_cell_connectivity` by `Mainak Jas`_ in `#376 `_ + Bug ~~~ From b1fe6719778a657c14ec83d85ffd51eb56a6a6ba Mon Sep 17 00:00:00 2001 From: Mainak Jas Date: Tue, 6 Jul 2021 17:16:42 -0400 Subject: [PATCH 6/6] FIX: address nick comments --- hnn_core/tests/test_viz.py | 10 +++++++--- hnn_core/viz.py | 24 ++++++++++++++---------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index 03d7e3678..7edc46ba1 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -3,6 +3,7 @@ import matplotlib import numpy as np +from numpy.testing import assert_allclose import pytest import hnn_core @@ -55,10 +56,10 @@ def test_network_visualization(): with pytest.raises(TypeError, match='src_gid must be an instance of'): plot_cell_connectivity(net, conn_idx, src_gid='blah') - with pytest.raises(ValueError, match='src_gid not a valid cell ID'): + with pytest.raises(ValueError, match='src_gid -1 not a valid cell ID'): plot_cell_connectivity(net, conn_idx, src_gid=-1) - # smoke test interactive clicking + # test interactive clicking updates the position of src_cell in plot del net.connectivity[-1] conn_idx = 15 net.add_connection(net.gid_ranges['L2_pyramidal'][::2], @@ -66,9 +67,12 @@ def test_network_visualization(): 'ampa', 0.00025, 1.0, lamtha=3.0, probability=0.8) fig = plot_cell_connectivity(net, conn_idx) - ax_src = fig.axes[0] + ax_src, ax_target, _ = fig.axes + pos = net.pos_dict['L2_pyramidal'][2] _fake_click(fig, ax_src, [pos[0], pos[1]]) + pos_in_plot = ax_target.collections[2].get_offsets().data[0] + assert_allclose(pos[:2], pos_in_plot) def test_dipole_visualization(): diff --git a/hnn_core/viz.py b/hnn_core/viz.py index e8335a95c..0c42a81b8 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -871,19 +871,23 @@ def _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos, def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, colorbar=True, colormap='viridis', show=True): - """Plot synaptic weight of connections originating from src_gid. + """Plot synaptic weight of connections. + + This is an interactive plot with source cells shown in the left + subplot and connectivity from a source cell to all the target cells + in the right subplot. Click on the cells in the left subplot to + explore how the connectivity pattern changes for different source cells. Parameters ---------- net : Instance of Network object The Network object conn_idx : int - Index of connection to be visualized - from `net.connectivity` + Index of connection to be visualized from net.connectivity src_gid : int | None The cell ID of the source cell. It must be an element of net.connectivity[conn_idx]['gid_pairs'].keys() - If None, the first src_gid from the list of valid src_gids is selected. + If None, the first cell from the list of valid src_gids is selected. axes : instance of Axes3D Matplotlib 3D axis colormap : str @@ -900,11 +904,11 @@ def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, Notes ----- - Target cells will be determined by the connection class given by + Target cells will be determined by the connections in net.connectivity[conn_idx]. - If the target cell is not connected to src_gid, it will appear as a - smaller black circle. - src_gid is plotted as a red circle. src_gid will not be plotted if + If the target cell is not connected to the source cell, + it will appear as a smaller black cross. + Source cell is plotted as a red square. Source cell will not be plotted if the connection corresponds to a drive, ex: poisson, bursty, etc. """ @@ -932,8 +936,8 @@ def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, _validate_type(src_gid, int, 'src_gid', 'int') if src_gid not in valid_src_gids: - raise ValueError(f'src_gid not a valid cell ID for this connection ' - f'Please select one of {valid_src_gids}') + raise ValueError(f'src_gid {src_gid} not a valid cell ID for this ' + f'connection. Please select one of {valid_src_gids}') target_range = np.array(conn['target_range'])