Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] ENH: interactive connectivity #376

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Changelog

- Add new function :func:`~hnn_core.viz.plot_cell_morphology` to visualize cell morphology, by `Mainak Jas`_ in `#319 <https://github.com/jonescompneurolab/hnn-core/pull/319>`_

- Compute dipole component in z-direction automatically from cell morphology instead of hard coding, by `Mainak Jas`_ in `#327 <https://github.com/jonescompneurolab/hnn-core/pull/320>`_
- Compute dipole component in z-direction automatically from cell morphology instead of hard coding, by `Mainak Jas`_ in `#327 <https://github.com/jonescompneurolab/hnn-core/pull/327>`_

- Store :class:`~hnn_core.Cell` instances in :class:`~hnn_core.Network`'s :attr:`~/hnn_core.Network.cells` attribute by `Ryan Thorpe`_ in `#321 <https://github.com/jonescompneurolab/hnn-core/pull/321>`_

Expand All @@ -32,6 +32,8 @@ Changelog

- Previously published models can now be loaded via ``net=law_2021_model()`` and ``jones_2009_model()``, by `Nick Tolley`_ in `#348 <https://github.com/jonescompneurolab/hnn-core/pull/348>`_

- Add ability to interactivity explore connections in :func:`~hnn_core.viz.plot_cell_connectivity` by `Mainak Jas`_ in `#376 <https://github.com/jonescompneurolab/hnn-core/pull/376>`_

Bug
~~~

Expand Down
2 changes: 1 addition & 1 deletion examples/howto/plot_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

gid_idx = 11
src_gid = net_erp.connectivity[conn_idx]['src_range'][gid_idx]
fig, ax = plot_cell_connectivity(net_erp, conn_idx, src_gid)
fig = plot_cell_connectivity(net_erp, conn_idx, src_gid)

###############################################################################
# Data recorded during simulations are stored under
Expand Down
22 changes: 21 additions & 1 deletion hnn_core/tests/test_viz.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
import os.path as op

import matplotlib
Expand All @@ -13,6 +14,13 @@
matplotlib.use('agg')


def _fake_click(fig, ax, point, button=1):
"""Fake a click at a point within axes."""
x, y = ax.transData.transform_point(point)
func = partial(fig.canvas.button_press_event, x=x, y=y, button=button)
func(guiEvent=None)


def test_network_visualization():
"""Test network visualisations."""
hnn_core_root = op.dirname(hnn_core.__file__)
Expand Down Expand Up @@ -47,9 +55,21 @@ def test_network_visualization():
with pytest.raises(TypeError, match='src_gid must be an instance of'):
plot_cell_connectivity(net, conn_idx, src_gid='blah')

with pytest.raises(ValueError, match='src_gid not in the'):
with pytest.raises(ValueError, match='src_gid not a valid cell ID'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
with pytest.raises(ValueError, match='src_gid not a valid cell ID'):
with pytest.raises(ValueError, match='src_gid -1 not a valid cell ID'):

plot_cell_connectivity(net, conn_idx, src_gid=-1)

# smoke test interactive clicking
del net.connectivity[-1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is the deletion necessary?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not strictly necessary but I thought this would be good practice, otherwise the same "kind of connection" would occur in two different net.connectivity elements and the net effect would be a higher weight than we intended?

conn_idx = 15
net.add_connection(net.gid_ranges['L2_pyramidal'][::2],
'L5_basket', 'soma',
'ampa', 0.00025, 1.0, lamtha=3.0,
probability=0.8)
fig = plot_cell_connectivity(net, conn_idx)
ax_src = fig.axes[0]
pos = net.pos_dict['L2_pyramidal'][2]
_fake_click(fig, ax_src, [pos[0], pos[1]])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test to make sure the figure is actually updated? A really simple one would be to grab the title with fig.texts and assert that it is different after running _fake_click.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're making me work harder for my merges ;-) It's good! Take a look at the last commit


def test_dipole_visualization():
"""Test dipole visualisations."""
Expand Down
139 changes: 97 additions & 42 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,8 +836,41 @@ def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True,
return ax.get_figure()


def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True,
colormap='viridis', show=True):
def _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos,
src_range, target_range, nc_dict, colormap):
from .cell import _get_gaussian_connection

# Extract indeces to get position in network
# Index in gid range aligns with net.pos_dict
target_src_pair = conn['gid_pairs'][src_gid]
target_indeces = np.where(np.in1d(target_range, target_src_pair))[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great solution, much cleaner

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's your code, I didn't do anything ... just moved it around :)


src_idx = np.where(src_range == src_gid)[0][0]
src_pos = src_type_pos[src_idx]

# Aggregate positions and weight of each connected target
weights, target_x_pos, target_y_pos = list(), list(), list()
for target_idx in target_indeces:
target_pos = target_type_pos[target_idx]
target_x_pos.append(target_pos[0])
target_y_pos.append(target_pos[1])
weight, _ = _get_gaussian_connection(src_pos, target_pos, nc_dict)
weights.append(weight)

ax.clear()
im = ax.scatter(target_x_pos, target_y_pos, c=weights, s=50,
cmap=colormap)
x_pos = target_type_pos[:, 0]
y_pos = target_type_pos[:, 1]
ax.scatter(x_pos, y_pos, color='k', marker='x', zorder=-1, s=20)
ax.scatter(src_pos[0], src_pos[1], marker='s', color='red', s=150)
ax.set_ylabel('Y Position')
ax.set_xlabel('X Position')
return im


def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update docstring to indicate interactive plot? Also it'd be good to add a notes section describing how to use it.

colorbar=True, colormap='viridis', show=True):
"""Plot synaptic weight of connections originating from src_gid.

Parameters
Expand All @@ -847,9 +880,11 @@ def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True,
conn_idx : int
Index of connection to be visualized
from `net.connectivity`
src_gid : int
Each cell in a network is uniquely identified by it's "global ID": GID.
ax : instance of Axes3D
src_gid : int | None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good description, but after writing so many docstrings it's making me think it will be worthwhile creating some consistency for the arguments that appear super often. Sort of in line with the hnn-core glossary we had talked about, these parameters could start with the boilerplate description, and then more function specific text.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will defer this to a separate PR: #387 :)

The cell ID of the source cell. It must be an element of
net.connectivity[conn_idx]['gid_pairs'].keys()
If None, the first src_gid from the list of valid src_gids is selected.
axes : instance of Axes3D
Matplotlib 3D axis
colormap : str
The name of a matplotlib colormap. Default: 'viridis'
Expand All @@ -875,63 +910,80 @@ def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True,
"""
import matplotlib.pyplot as plt
from .network import Network
from .cell import _get_gaussian_connection
from matplotlib.ticker import ScalarFormatter

_validate_type(net, Network, 'net', 'Network')
_validate_type(conn_idx, int, 'conn_idx', 'int')
_validate_type(src_gid, int, 'src_gid', 'int')
if ax is None:
_, ax = plt.subplots(1, 1)

# Load objects for distance calculation
conn = net.connectivity[conn_idx]
nc_dict = conn['nc_dict']
src_type = conn['src_type']
target_type = conn['target_type']
src_type_pos = net.pos_dict[src_type]
target_type_pos = net.pos_dict[target_type]

src_type_pos = np.array(net.pos_dict[src_type])
target_type_pos = np.array(net.pos_dict[target_type])
src_range = np.array(conn['src_range'])
if src_gid not in src_range:
raise ValueError(f'src_gid not in the src type range of {src_type} '
f'gids. Valid gids include {src_range[0]} -> '
f'{src_range[-1]}')

target_range = np.array(conn['target_range'])
valid_src_gids = list(net.connectivity[conn_idx]['gid_pairs'].keys())
src_pos_valid = src_type_pos[np.in1d(src_range, valid_src_gids)]

# Extract indeces to get position in network
# Index in gid range aligns with net.pos_dict
target_src_pair = conn['gid_pairs'][src_gid]
target_indeces = np.where(np.in1d(target_range, target_src_pair))[0]
if src_gid is None:
src_gid = valid_src_gids[0]
_validate_type(src_gid, int, 'src_gid', 'int')

src_idx = np.where(src_range == src_gid)[0][0]
src_pos = src_type_pos[src_idx]
if src_gid not in valid_src_gids:
raise ValueError(f'src_gid not a valid cell ID for this connection '
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise ValueError(f'src_gid not a valid cell ID for this connection '
raise ValueError(f'src_gid {src_gid} not a valid cell ID for this connection '

f'Please select one of {valid_src_gids}')

# Aggregate positions and weight of each connected target
weights, target_x_pos, target_y_pos = list(), list(), list()
for target_idx in target_indeces:
target_pos = target_type_pos[target_idx]
target_x_pos.append(target_pos[0])
target_y_pos.append(target_pos[1])
weight, _ = _get_gaussian_connection(src_pos, target_pos, nc_dict)
weights.append(weight)
target_range = np.array(conn['target_range'])

im = ax.scatter(target_x_pos, target_y_pos, c=weights, s=50, cmap=colormap)
if axes is None:
if src_type in net.cell_types:
fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)
else:
fig, axes = plt.subplots(1, 1, sharex=True, sharey=True)
axes = [axes]

# Gather positions of all gids in target_type.
x_pos = [target_type_pos[idx][0] for idx in range(len(target_type_pos))]
y_pos = [target_type_pos[idx][1] for idx in range(len(target_type_pos))]
ax.scatter(x_pos, y_pos, color='k', marker='x', zorder=-1, s=20)
if len(axes) == 2:
ax_src, ax = axes
else:
ax = axes[0]

# Only plot src_gid if proper cell type.
im = _update_target_plot(ax, conn, src_gid, src_type_pos,
target_type_pos, src_range,
target_range, nc_dict, colormap)

x_src = src_type_pos[:, 0]
y_src = src_type_pos[:, 1]
x_src_valid = src_pos_valid[:, 0]
y_src_valid = src_pos_valid[:, 1]
if src_type in net.cell_types:
ax.scatter(src_pos[0], src_pos[1], marker='s', color='red', s=150)
ax.set_ylabel('Y Position')
ax.set_xlabel('X Position')
ax.set_title(f"{conn['src_type']}-> {conn['target_type']}"
ax_src.scatter(x_src, y_src, marker='s', color='red', s=50,
alpha=0.2)
ax_src.scatter(x_src_valid, y_src_valid, marker='s', color='red',
s=50)

plt.suptitle(f"{conn['src_type']}-> {conn['target_type']}"
f" ({conn['loc']}, {conn['receptor']})")

def _onclick(event):
if event.inaxes in [ax] or event.inaxes is None:
return

dist = np.linalg.norm(src_type_pos[:, :2] -
np.array([event.xdata, event.ydata]),
axis=1)
src_idx = np.argmin(dist)

src_gid = src_range[src_idx]
if src_gid not in valid_src_gids:
return
_update_target_plot(ax, conn, src_gid, src_type_pos,
target_type_pos, src_range, target_range,
nc_dict, colormap)

fig.canvas.draw()

if colorbar:
fig = ax.get_figure()
xfmt = ScalarFormatter()
Expand All @@ -941,5 +993,8 @@ def plot_cell_connectivity(net, conn_idx, src_gid, ax=None, colorbar=True,
cbar.ax.set_ylabel('Weight', rotation=-90, va="bottom")

plt.tight_layout()

fig.canvas.mpl_connect('button_press_event', _onclick)

plt_show(show)
return ax.get_figure(), ax
return ax.get_figure()