-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] ENH: interactive connectivity #376
Changes from all commits
490ac1f
39ad1bc
d3c7c92
9380d5c
da2b9e7
b1fe671
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
from functools import partial | ||
import os.path as op | ||
|
||
import matplotlib | ||
import numpy as np | ||
from numpy.testing import assert_allclose | ||
import pytest | ||
|
||
import hnn_core | ||
|
@@ -13,6 +15,13 @@ | |
matplotlib.use('agg') | ||
|
||
|
||
def _fake_click(fig, ax, point, button=1): | ||
"""Fake a click at a point within axes.""" | ||
x, y = ax.transData.transform_point(point) | ||
func = partial(fig.canvas.button_press_event, x=x, y=y, button=button) | ||
func(guiEvent=None) | ||
|
||
|
||
def test_network_visualization(): | ||
"""Test network visualisations.""" | ||
hnn_core_root = op.dirname(hnn_core.__file__) | ||
|
@@ -47,9 +56,24 @@ def test_network_visualization(): | |
with pytest.raises(TypeError, match='src_gid must be an instance of'): | ||
plot_cell_connectivity(net, conn_idx, src_gid='blah') | ||
|
||
with pytest.raises(ValueError, match='src_gid not in the'): | ||
with pytest.raises(ValueError, match='src_gid -1 not a valid cell ID'): | ||
plot_cell_connectivity(net, conn_idx, src_gid=-1) | ||
|
||
# test interactive clicking updates the position of src_cell in plot | ||
del net.connectivity[-1] | ||
conn_idx = 15 | ||
net.add_connection(net.gid_ranges['L2_pyramidal'][::2], | ||
'L5_basket', 'soma', | ||
'ampa', 0.00025, 1.0, lamtha=3.0, | ||
probability=0.8) | ||
fig = plot_cell_connectivity(net, conn_idx) | ||
ax_src, ax_target, _ = fig.axes | ||
|
||
pos = net.pos_dict['L2_pyramidal'][2] | ||
_fake_click(fig, ax_src, [pos[0], pos[1]]) | ||
pos_in_plot = ax_target.collections[2].get_offsets().data[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very fancy! Great unit test |
||
assert_allclose(pos[:2], pos_in_plot) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a test to make sure the figure is actually updated? A really simple one would be to grab the title with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you're making me work harder for my merges ;-) It's good! Take a look at the last commit |
||
|
||
def test_dipole_visualization(): | ||
"""Test dipole visualisations.""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -836,20 +836,59 @@ def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True, | |
return ax.get_figure() | ||
|
||
|
||
def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True, | ||
colormap='viridis', show=True): | ||
"""Plot synaptic weight of connections originating from src_gid. | ||
def _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos, | ||
src_range, target_range, nc_dict, colormap): | ||
from .cell import _get_gaussian_connection | ||
|
||
# Extract indeces to get position in network | ||
# Index in gid range aligns with net.pos_dict | ||
target_src_pair = conn['gid_pairs'][src_gid] | ||
target_indeces = np.where(np.in1d(target_range, target_src_pair))[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great solution, much cleaner There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's your code, I didn't do anything ... just moved it around :) |
||
|
||
src_idx = np.where(src_range == src_gid)[0][0] | ||
src_pos = src_type_pos[src_idx] | ||
|
||
# Aggregate positions and weight of each connected target | ||
weights, target_x_pos, target_y_pos = list(), list(), list() | ||
for target_idx in target_indeces: | ||
target_pos = target_type_pos[target_idx] | ||
target_x_pos.append(target_pos[0]) | ||
target_y_pos.append(target_pos[1]) | ||
weight, _ = _get_gaussian_connection(src_pos, target_pos, nc_dict) | ||
weights.append(weight) | ||
|
||
ax.clear() | ||
im = ax.scatter(target_x_pos, target_y_pos, c=weights, s=50, | ||
cmap=colormap) | ||
x_pos = target_type_pos[:, 0] | ||
y_pos = target_type_pos[:, 1] | ||
ax.scatter(x_pos, y_pos, color='k', marker='x', zorder=-1, s=20) | ||
ax.scatter(src_pos[0], src_pos[1], marker='s', color='red', s=150) | ||
ax.set_ylabel('Y Position') | ||
ax.set_xlabel('X Position') | ||
return im | ||
|
||
|
||
def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update docstring to indicate interactive plot? Also it'd be good to add a notes section describing how to use it. |
||
colorbar=True, colormap='viridis', show=True): | ||
"""Plot synaptic weight of connections. | ||
|
||
This is an interactive plot with source cells shown in the left | ||
subplot and connectivity from a source cell to all the target cells | ||
in the right subplot. Click on the cells in the left subplot to | ||
explore how the connectivity pattern changes for different source cells. | ||
|
||
Parameters | ||
---------- | ||
net : Instance of Network object | ||
The Network object | ||
conn_idx : int | ||
Index of connection to be visualized | ||
from `net.connectivity` | ||
src_gid : int | ||
Each cell in a network is uniquely identified by it's "global ID": GID. | ||
ax : instance of Axes3D | ||
Index of connection to be visualized from net.connectivity | ||
src_gid : int | None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a good description, but after writing so many docstrings it's making me think it will be worthwhile creating some consistency for the arguments that appear super often. Sort of in line with the hnn-core glossary we had talked about, these parameters could start with the boilerplate description, and then more function specific text. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will defer this to a separate PR: #387 :) |
||
The cell ID of the source cell. It must be an element of | ||
net.connectivity[conn_idx]['gid_pairs'].keys() | ||
If None, the first cell from the list of valid src_gids is selected. | ||
axes : instance of Axes3D | ||
Matplotlib 3D axis | ||
colormap : str | ||
The name of a matplotlib colormap. Default: 'viridis' | ||
|
@@ -865,73 +904,90 @@ def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True, | |
|
||
Notes | ||
----- | ||
Target cells will be determined by the connection class given by | ||
Target cells will be determined by the connections in | ||
net.connectivity[conn_idx]. | ||
If the target cell is not connected to src_gid, it will appear as a | ||
smaller black circle. | ||
src_gid is plotted as a red circle. src_gid will not be plotted if | ||
If the target cell is not connected to the source cell, | ||
it will appear as a smaller black cross. | ||
Source cell is plotted as a red square. Source cell will not be plotted if | ||
the connection corresponds to a drive, ex: poisson, bursty, etc. | ||
|
||
""" | ||
import matplotlib.pyplot as plt | ||
from .network import Network | ||
from .cell import _get_gaussian_connection | ||
from matplotlib.ticker import ScalarFormatter | ||
|
||
_validate_type(net, Network, 'net', 'Network') | ||
_validate_type(conn_idx, int, 'conn_idx', 'int') | ||
_validate_type(src_gid, int, 'src_gid', 'int') | ||
if ax is None: | ||
_, ax = plt.subplots(1, 1) | ||
|
||
# Load objects for distance calculation | ||
conn = net.connectivity[conn_idx] | ||
nc_dict = conn['nc_dict'] | ||
src_type = conn['src_type'] | ||
target_type = conn['target_type'] | ||
src_type_pos = net.pos_dict[src_type] | ||
target_type_pos = net.pos_dict[target_type] | ||
|
||
src_type_pos = np.array(net.pos_dict[src_type]) | ||
target_type_pos = np.array(net.pos_dict[target_type]) | ||
src_range = np.array(conn['src_range']) | ||
if src_gid not in src_range: | ||
raise ValueError(f'src_gid not in the src type range of {src_type} ' | ||
f'gids. Valid gids include {src_range[0]} -> ' | ||
f'{src_range[-1]}') | ||
|
||
target_range = np.array(conn['target_range']) | ||
valid_src_gids = list(net.connectivity[conn_idx]['gid_pairs'].keys()) | ||
src_pos_valid = src_type_pos[np.in1d(src_range, valid_src_gids)] | ||
|
||
# Extract indeces to get position in network | ||
# Index in gid range aligns with net.pos_dict | ||
target_src_pair = conn['gid_pairs'][src_gid] | ||
target_indeces = np.where(np.in1d(target_range, target_src_pair))[0] | ||
if src_gid is None: | ||
src_gid = valid_src_gids[0] | ||
_validate_type(src_gid, int, 'src_gid', 'int') | ||
|
||
src_idx = np.where(src_range == src_gid)[0][0] | ||
src_pos = src_type_pos[src_idx] | ||
if src_gid not in valid_src_gids: | ||
raise ValueError(f'src_gid {src_gid} not a valid cell ID for this ' | ||
f'connection. Please select one of {valid_src_gids}') | ||
|
||
# Aggregate positions and weight of each connected target | ||
weights, target_x_pos, target_y_pos = list(), list(), list() | ||
for target_idx in target_indeces: | ||
target_pos = target_type_pos[target_idx] | ||
target_x_pos.append(target_pos[0]) | ||
target_y_pos.append(target_pos[1]) | ||
weight, _ = _get_gaussian_connection(src_pos, target_pos, nc_dict) | ||
weights.append(weight) | ||
target_range = np.array(conn['target_range']) | ||
|
||
im = ax.scatter(target_x_pos, target_y_pos, c=weights, s=50, cmap=colormap) | ||
if axes is None: | ||
if src_type in net.cell_types: | ||
fig, axes = plt.subplots(1, 2, sharex=True, sharey=True) | ||
else: | ||
fig, axes = plt.subplots(1, 1, sharex=True, sharey=True) | ||
axes = [axes] | ||
|
||
# Gather positions of all gids in target_type. | ||
x_pos = [target_type_pos[idx][0] for idx in range(len(target_type_pos))] | ||
y_pos = [target_type_pos[idx][1] for idx in range(len(target_type_pos))] | ||
ax.scatter(x_pos, y_pos, color='k', marker='x', zorder=-1, s=20) | ||
if len(axes) == 2: | ||
ax_src, ax = axes | ||
else: | ||
ax = axes[0] | ||
|
||
im = _update_target_plot(ax, conn, src_gid, src_type_pos, | ||
target_type_pos, src_range, | ||
target_range, nc_dict, colormap) | ||
|
||
# Only plot src_gid if proper cell type. | ||
x_src = src_type_pos[:, 0] | ||
y_src = src_type_pos[:, 1] | ||
x_src_valid = src_pos_valid[:, 0] | ||
y_src_valid = src_pos_valid[:, 1] | ||
if src_type in net.cell_types: | ||
ax.scatter(src_pos[0], src_pos[1], marker='s', color='red', s=150) | ||
ax.set_ylabel('Y Position') | ||
ax.set_xlabel('X Position') | ||
ax.set_title(f"{conn['src_type']}-> {conn['target_type']}" | ||
ax_src.scatter(x_src, y_src, marker='s', color='red', s=50, | ||
alpha=0.2) | ||
ax_src.scatter(x_src_valid, y_src_valid, marker='s', color='red', | ||
s=50) | ||
|
||
plt.suptitle(f"{conn['src_type']}-> {conn['target_type']}" | ||
f" ({conn['loc']}, {conn['receptor']})") | ||
|
||
def _onclick(event): | ||
if event.inaxes in [ax] or event.inaxes is None: | ||
return | ||
|
||
dist = np.linalg.norm(src_type_pos[:, :2] - | ||
np.array([event.xdata, event.ydata]), | ||
axis=1) | ||
src_idx = np.argmin(dist) | ||
|
||
src_gid = src_range[src_idx] | ||
if src_gid not in valid_src_gids: | ||
return | ||
_update_target_plot(ax, conn, src_gid, src_type_pos, | ||
target_type_pos, src_range, target_range, | ||
nc_dict, colormap) | ||
|
||
fig.canvas.draw() | ||
|
||
if colorbar: | ||
fig = ax.get_figure() | ||
xfmt = ScalarFormatter() | ||
|
@@ -941,5 +997,8 @@ def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True, | |
cbar.ax.set_ylabel('Weight', rotation=-90, va="bottom") | ||
|
||
plt.tight_layout() | ||
|
||
fig.canvas.mpl_connect('button_press_event', _onclick) | ||
|
||
plt_show(show) | ||
return ax.get_figure(), ax | ||
return ax.get_figure() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is the deletion necessary?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not strictly necessary but I thought this would be good practice, otherwise the same "kind of connection" would occur in two different
net.connectivity
elements and the net effect would be a higher weight than we intended?