-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] ENH: interactive connectivity #376
Changes from 5 commits
490ac1f
39ad1bc
d3c7c92
9380d5c
da2b9e7
b1fe671
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from functools import partial | ||
import os.path as op | ||
|
||
import matplotlib | ||
|
@@ -13,6 +14,13 @@ | |
matplotlib.use('agg') | ||
|
||
|
||
def _fake_click(fig, ax, point, button=1): | ||
"""Fake a click at a point within axes.""" | ||
x, y = ax.transData.transform_point(point) | ||
func = partial(fig.canvas.button_press_event, x=x, y=y, button=button) | ||
func(guiEvent=None) | ||
|
||
|
||
def test_network_visualization(): | ||
"""Test network visualisations.""" | ||
hnn_core_root = op.dirname(hnn_core.__file__) | ||
|
@@ -47,9 +55,21 @@ def test_network_visualization(): | |
with pytest.raises(TypeError, match='src_gid must be an instance of'): | ||
plot_cell_connectivity(net, conn_idx, src_gid='blah') | ||
|
||
with pytest.raises(ValueError, match='src_gid not in the'): | ||
with pytest.raises(ValueError, match='src_gid not a valid cell ID'): | ||
plot_cell_connectivity(net, conn_idx, src_gid=-1) | ||
|
||
# smoke test interactive clicking | ||
del net.connectivity[-1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is the deletion necessary? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not strictly necessary but I thought this would be good practice, otherwise the same "kind of connection" would occur in two different |
||
conn_idx = 15 | ||
net.add_connection(net.gid_ranges['L2_pyramidal'][::2], | ||
'L5_basket', 'soma', | ||
'ampa', 0.00025, 1.0, lamtha=3.0, | ||
probability=0.8) | ||
fig = plot_cell_connectivity(net, conn_idx) | ||
ax_src = fig.axes[0] | ||
pos = net.pos_dict['L2_pyramidal'][2] | ||
_fake_click(fig, ax_src, [pos[0], pos[1]]) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a test to make sure the figure is actually updated? A really simple one would be to grab the title with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you're making me work harder for my merges ;-) It's good! Take a look at the last commit |
||
|
||
def test_dipole_visualization(): | ||
"""Test dipole visualisations.""" | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -836,8 +836,41 @@ def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True, | |||||
return ax.get_figure() | ||||||
|
||||||
|
||||||
def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True, | ||||||
colormap='viridis', show=True): | ||||||
def _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos, | ||||||
src_range, target_range, nc_dict, colormap): | ||||||
from .cell import _get_gaussian_connection | ||||||
|
||||||
# Extract indeces to get position in network | ||||||
# Index in gid range aligns with net.pos_dict | ||||||
target_src_pair = conn['gid_pairs'][src_gid] | ||||||
target_indeces = np.where(np.in1d(target_range, target_src_pair))[0] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great solution, much cleaner There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's your code, I didn't do anything ... just moved it around :) |
||||||
|
||||||
src_idx = np.where(src_range == src_gid)[0][0] | ||||||
src_pos = src_type_pos[src_idx] | ||||||
|
||||||
# Aggregate positions and weight of each connected target | ||||||
weights, target_x_pos, target_y_pos = list(), list(), list() | ||||||
for target_idx in target_indeces: | ||||||
target_pos = target_type_pos[target_idx] | ||||||
target_x_pos.append(target_pos[0]) | ||||||
target_y_pos.append(target_pos[1]) | ||||||
weight, _ = _get_gaussian_connection(src_pos, target_pos, nc_dict) | ||||||
weights.append(weight) | ||||||
|
||||||
ax.clear() | ||||||
im = ax.scatter(target_x_pos, target_y_pos, c=weights, s=50, | ||||||
cmap=colormap) | ||||||
x_pos = target_type_pos[:, 0] | ||||||
y_pos = target_type_pos[:, 1] | ||||||
ax.scatter(x_pos, y_pos, color='k', marker='x', zorder=-1, s=20) | ||||||
ax.scatter(src_pos[0], src_pos[1], marker='s', color='red', s=150) | ||||||
ax.set_ylabel('Y Position') | ||||||
ax.set_xlabel('X Position') | ||||||
return im | ||||||
|
||||||
|
||||||
def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update docstring to indicate interactive plot? Also it'd be good to add a notes section describing how to use it. |
||||||
colorbar=True, colormap='viridis', show=True): | ||||||
"""Plot synaptic weight of connections originating from src_gid. | ||||||
|
||||||
Parameters | ||||||
|
@@ -847,9 +880,11 @@ def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True, | |||||
conn_idx : int | ||||||
Index of connection to be visualized | ||||||
from `net.connectivity` | ||||||
src_gid : int | ||||||
Each cell in a network is uniquely identified by it's "global ID": GID. | ||||||
ax : instance of Axes3D | ||||||
src_gid : int | None | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a good description, but after writing so many docstrings it's making me think it will be worthwhile creating some consistency for the arguments that appear super often. Sort of in line with the hnn-core glossary we had talked about, these parameters could start with the boilerplate description, and then more function specific text. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will defer this to a separate PR: #387 :) |
||||||
The cell ID of the source cell. It must be an element of | ||||||
net.connectivity[conn_idx]['gid_pairs'].keys() | ||||||
If None, the first src_gid from the list of valid src_gids is selected. | ||||||
axes : instance of Axes3D | ||||||
Matplotlib 3D axis | ||||||
colormap : str | ||||||
The name of a matplotlib colormap. Default: 'viridis' | ||||||
|
@@ -875,63 +910,80 @@ def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True, | |||||
""" | ||||||
import matplotlib.pyplot as plt | ||||||
from .network import Network | ||||||
from .cell import _get_gaussian_connection | ||||||
from matplotlib.ticker import ScalarFormatter | ||||||
|
||||||
_validate_type(net, Network, 'net', 'Network') | ||||||
_validate_type(conn_idx, int, 'conn_idx', 'int') | ||||||
_validate_type(src_gid, int, 'src_gid', 'int') | ||||||
if ax is None: | ||||||
_, ax = plt.subplots(1, 1) | ||||||
|
||||||
# Load objects for distance calculation | ||||||
conn = net.connectivity[conn_idx] | ||||||
nc_dict = conn['nc_dict'] | ||||||
src_type = conn['src_type'] | ||||||
target_type = conn['target_type'] | ||||||
src_type_pos = net.pos_dict[src_type] | ||||||
target_type_pos = net.pos_dict[target_type] | ||||||
|
||||||
src_type_pos = np.array(net.pos_dict[src_type]) | ||||||
target_type_pos = np.array(net.pos_dict[target_type]) | ||||||
src_range = np.array(conn['src_range']) | ||||||
if src_gid not in src_range: | ||||||
raise ValueError(f'src_gid not in the src type range of {src_type} ' | ||||||
f'gids. Valid gids include {src_range[0]} -> ' | ||||||
f'{src_range[-1]}') | ||||||
|
||||||
target_range = np.array(conn['target_range']) | ||||||
valid_src_gids = list(net.connectivity[conn_idx]['gid_pairs'].keys()) | ||||||
src_pos_valid = src_type_pos[np.in1d(src_range, valid_src_gids)] | ||||||
|
||||||
# Extract indeces to get position in network | ||||||
# Index in gid range aligns with net.pos_dict | ||||||
target_src_pair = conn['gid_pairs'][src_gid] | ||||||
target_indeces = np.where(np.in1d(target_range, target_src_pair))[0] | ||||||
if src_gid is None: | ||||||
src_gid = valid_src_gids[0] | ||||||
_validate_type(src_gid, int, 'src_gid', 'int') | ||||||
|
||||||
src_idx = np.where(src_range == src_gid)[0][0] | ||||||
src_pos = src_type_pos[src_idx] | ||||||
if src_gid not in valid_src_gids: | ||||||
raise ValueError(f'src_gid not a valid cell ID for this connection ' | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
f'Please select one of {valid_src_gids}') | ||||||
|
||||||
# Aggregate positions and weight of each connected target | ||||||
weights, target_x_pos, target_y_pos = list(), list(), list() | ||||||
for target_idx in target_indeces: | ||||||
target_pos = target_type_pos[target_idx] | ||||||
target_x_pos.append(target_pos[0]) | ||||||
target_y_pos.append(target_pos[1]) | ||||||
weight, _ = _get_gaussian_connection(src_pos, target_pos, nc_dict) | ||||||
weights.append(weight) | ||||||
target_range = np.array(conn['target_range']) | ||||||
|
||||||
im = ax.scatter(target_x_pos, target_y_pos, c=weights, s=50, cmap=colormap) | ||||||
if axes is None: | ||||||
if src_type in net.cell_types: | ||||||
fig, axes = plt.subplots(1, 2, sharex=True, sharey=True) | ||||||
else: | ||||||
fig, axes = plt.subplots(1, 1, sharex=True, sharey=True) | ||||||
axes = [axes] | ||||||
|
||||||
# Gather positions of all gids in target_type. | ||||||
x_pos = [target_type_pos[idx][0] for idx in range(len(target_type_pos))] | ||||||
y_pos = [target_type_pos[idx][1] for idx in range(len(target_type_pos))] | ||||||
ax.scatter(x_pos, y_pos, color='k', marker='x', zorder=-1, s=20) | ||||||
if len(axes) == 2: | ||||||
ax_src, ax = axes | ||||||
else: | ||||||
ax = axes[0] | ||||||
|
||||||
# Only plot src_gid if proper cell type. | ||||||
im = _update_target_plot(ax, conn, src_gid, src_type_pos, | ||||||
target_type_pos, src_range, | ||||||
target_range, nc_dict, colormap) | ||||||
|
||||||
x_src = src_type_pos[:, 0] | ||||||
y_src = src_type_pos[:, 1] | ||||||
x_src_valid = src_pos_valid[:, 0] | ||||||
y_src_valid = src_pos_valid[:, 1] | ||||||
if src_type in net.cell_types: | ||||||
ax.scatter(src_pos[0], src_pos[1], marker='s', color='red', s=150) | ||||||
ax.set_ylabel('Y Position') | ||||||
ax.set_xlabel('X Position') | ||||||
ax.set_title(f"{conn['src_type']}-> {conn['target_type']}" | ||||||
ax_src.scatter(x_src, y_src, marker='s', color='red', s=50, | ||||||
alpha=0.2) | ||||||
ax_src.scatter(x_src_valid, y_src_valid, marker='s', color='red', | ||||||
s=50) | ||||||
|
||||||
plt.suptitle(f"{conn['src_type']}-> {conn['target_type']}" | ||||||
f" ({conn['loc']}, {conn['receptor']})") | ||||||
|
||||||
def _onclick(event): | ||||||
if event.inaxes in [ax] or event.inaxes is None: | ||||||
return | ||||||
|
||||||
dist = np.linalg.norm(src_type_pos[:, :2] - | ||||||
np.array([event.xdata, event.ydata]), | ||||||
axis=1) | ||||||
src_idx = np.argmin(dist) | ||||||
|
||||||
src_gid = src_range[src_idx] | ||||||
if src_gid not in valid_src_gids: | ||||||
return | ||||||
_update_target_plot(ax, conn, src_gid, src_type_pos, | ||||||
target_type_pos, src_range, target_range, | ||||||
nc_dict, colormap) | ||||||
|
||||||
fig.canvas.draw() | ||||||
|
||||||
if colorbar: | ||||||
fig = ax.get_figure() | ||||||
xfmt = ScalarFormatter() | ||||||
|
@@ -941,5 +993,8 @@ def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True, | |||||
cbar.ax.set_ylabel('Weight', rotation=-90, va="bottom") | ||||||
|
||||||
plt.tight_layout() | ||||||
|
||||||
fig.canvas.mpl_connect('button_press_event', _onclick) | ||||||
|
||||||
plt_show(show) | ||||||
return ax.get_figure(), ax | ||||||
return ax.get_figure() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.