Skip to content
This repository has been archived by the owner on Jun 6, 2023. It is now read-only.

Commit

Permalink
Merge pull request #64 from SpikeInterface/interactive_electrode
Browse files Browse the repository at this point in the history
Show electrode label on electrode widget upon clicking
  • Loading branch information
alejoe91 authored Jul 31, 2020
2 parents 0a44d74 + 5ceeed9 commit 1f02ae8
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 74 deletions.
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import numpy as np
from matplotlib.patches import Ellipse
from ..utils import LabeledEllipse
from spikewidgets.widgets.basewidget import BaseWidget


def plot_electrode_geometry(recording, markersize=20, marker='o', figure=None, ax=None):
def plot_electrode_geometry(recording, color='C0', label_color='r', figure=None, ax=None):
"""
Plots electrode geometry.
Parameters
----------
recording: RecordingExtractor
The recordng extractor object
markersize: int
The size of the marker for the electrodes
marker: str
The matplotlib marker to use (default 'o')
color: matplotlib color
The color of the electrodes
label_color: matplotlib color
The color of the channel label when clicking
figure: matplotlib figure
The figure to be used. If not given a figure is created
ax: matplotlib axis
Expand All @@ -26,8 +28,8 @@ def plot_electrode_geometry(recording, markersize=20, marker='o', figure=None, a
"""
W = ElectrodeGeometryWidget(
recording=recording,
markersize=markersize,
marker=marker,
color=color,
label_color=label_color,
figure=figure,
ax=ax
)
Expand All @@ -36,34 +38,51 @@ def plot_electrode_geometry(recording, markersize=20, marker='o', figure=None, a


class ElectrodeGeometryWidget(BaseWidget):
def __init__(self, *, recording, markersize=10, marker='o', figure=None, ax=None):
def __init__(self, *, recording, color='C0', label_color='r', figure=None, ax=None):
if 'location' not in recording.get_shared_channel_property_names():
raise AttributeError("'location' not found as a property")
BaseWidget.__init__(self, figure, ax)
self._recording = recording
self._ms = markersize
self._mark = marker
self._color = color
self._label_color = label_color
self.name = 'ElectrodeGeometry'

def plot(self, width=4, height=4):
self._do_plot(width=width, height=height)

def _do_plot(self, width, height):
geom = np.array(self._recording.get_channel_locations())
locations = np.array(self._recording.get_channel_locations())
self.ax.axis('off')

x = geom[:, 0]
y = geom[:, 1]
xmin = np.min(x)
xmax = np.max(x)
ymin = np.min(y)
ymax = np.max(y)
x = locations[:, 0]
y = locations[:, 1]
x_un = np.unique(x)
y_un = np.unique(y)

margin = np.maximum(xmax - xmin, ymax - ymin) * 0.2
if len(y_un) == 1:
pitch_x = np.min(np.diff(x_un))
pitch_y = pitch_x
elif len(x_un) == 1:
pitch_y = np.min(np.diff(y_un))
pitch_x = pitch_y
else:
pitch_x = np.min(np.diff(x_un))
pitch_y = np.min(np.diff(y_un))

self._drs = []
elec_x = 0.9 * pitch_x
elec_y = 0.9 * pitch_y
for (loc, ch) in zip(locations, self._recording.get_channel_ids()):
ell = Ellipse((loc[0] - elec_x / 2, loc[1] - elec_y / 2), elec_x, elec_y,
color=self._color, alpha=0.9)
self.ax.add_patch(ell)
dr = LabeledEllipse(ell, ch, self._label_color)
dr.connect()
self._drs.append(dr)

self.ax.scatter(x, y, marker=self._mark, s=int(self._ms))
self.ax.axis('equal')
self.ax.set_xticks([])
self.ax.set_yticks([])
self.ax.set_xlim(xmin - margin, xmax + margin)
self.ax.set_ylim(ymin - margin, ymax + margin)
self.ax.set_xlim(np.min(x) - pitch_x, np.max(x) + pitch_x)
self.ax.set_ylim(np.min(y) - pitch_y, np.max(y) + pitch_y)

12 changes: 7 additions & 5 deletions spikewidgets/widgets/mapswidget/activitymapwidget.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import numpy as np
import spiketoolkit as st
import matplotlib.pylab as plt
from .utils import LabeledRectangle
from ..utils import LabeledRectangle
from spikewidgets.widgets.basewidget import BaseWidget


def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', background='on', label_color='r',
def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', background='on', label_color='r',
transpose=False, frame=False, ax=None, figure=None):
"""
Plots spike rate (estimated using simple threshold detector) as 2D activity map.
Expand Down Expand Up @@ -52,7 +52,8 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis',

class ActivityMapWidget(BaseWidget):

def __init__(self, recording, channel_ids, trange, cmap, background, label_color='r', transpose=False, frame=False, figure=None, ax=None):
def __init__(self, recording, channel_ids, trange, cmap, background, label_color='r', transpose=False, frame=False,
figure=None, ax=None):
BaseWidget.__init__(self, figure, ax)
self._recording = recording
self._channel_ids = channel_ids
Expand Down Expand Up @@ -81,7 +82,7 @@ def _do_plot(self):
start_frame=self._trange[0],
end_frame=self._trange[1])
if self._transpose:
locations = np.roll(locations,1,axis=1)
locations = np.roll(locations, 1, axis=1)

x = locations[:, 0]
y = locations[:, 1]
Expand Down Expand Up @@ -121,7 +122,8 @@ def _do_plot(self):
self.ax.set_xlim(np.min(x) - pitch_x, np.max(x) + pitch_x)
self.ax.set_ylim(np.min(y) - pitch_y, np.max(y) + pitch_y)
if self._frame:
rect = plt.Rectangle((np.min(x) - pitch_x, np.min(y) - pitch_y), np.max(x)-np.min(x) + 2*pitch_x, np.max(y) - np.min(y) + 2*pitch_y, fill=None, edgecolor='k')
rect = plt.Rectangle((np.min(x) - pitch_x, np.min(y) - pitch_y), np.max(x) - np.min(x) + 2 * pitch_x,
np.max(y) - np.min(y) + 2 * pitch_y, fill=None, edgecolor='k')
self.ax.add_patch(rect)
self.ax.axis('equal')
self.ax.axis('off')
2 changes: 1 addition & 1 deletion spikewidgets/widgets/mapswidget/templatemapswidget.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import spiketoolkit as st
import matplotlib.pylab as plt
from .utils import LabeledRectangle
from ..utils import LabeledRectangle
from spikewidgets.widgets.basewidget import BaseMultiWidget


Expand Down
47 changes: 0 additions & 47 deletions spikewidgets/widgets/mapswidget/utils.py

This file was deleted.

95 changes: 95 additions & 0 deletions spikewidgets/widgets/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@



class LabeledRectangle:
lock = None # only one can be animated at a time

def __init__(self, rect, channel, color):
self.rect = rect
self.press = None
self.background = None
self.channel_str = str(channel)
axes = self.rect.axes
x0, y0 = self.rect.xy
self.text = axes.text(x0, y0, self.channel_str, color=color)
self.text.set_visible(False)

def connect(self):
'connect to all the events we need'
self.cidpress = self.rect.figure.canvas.mpl_connect('button_press_event', self.on_press)
self.cidrelease = self.rect.figure.canvas.mpl_connect('button_release_event', self.on_release)

def on_press(self, event):
'on button press we will see if the mouse is over us and store some data'
if event.inaxes != self.rect.axes:
return
if LabeledRectangle.lock is not None:
return
contains, attrd = self.rect.contains(event)
if not contains: return
x0, y0 = self.rect.xy
self.press = x0, y0, event.xdata, event.ydata
LabeledRectangle.lock = self
self.text.set_visible(True)
self.text.draw()

def on_release(self, event):
'on release we reset the press data'
if LabeledRectangle.lock is not self:
return
self.press = None
LabeledRectangle.lock = None
self.text.set_visible(False)
self.text.draw()

def disconnect(self):
'disconnect all the stored connection ids'
self.rect.figure.canvas.mpl_disconnect(self.cidpress)
self.rect.figure.canvas.mpl_disconnect(self.cidrelease)


class LabeledEllipse:
lock = None # only one can be animated at a time

def __init__(self, ell, channel, color):
self.ell = ell
self.press = None
self.background = None
self.channel_str = str(channel)
axes = self.ell.axes
x0, y0 = self.ell.get_center()
self.text = axes.text(x0, y0, self.channel_str, color=color)
self.text.set_visible(False)

def connect(self):
'connect to all the events we need'
self.cidpress = self.ell.figure.canvas.mpl_connect('button_press_event', self.on_press)
self.cidrelease = self.ell.figure.canvas.mpl_connect('button_release_event', self.on_release)

def on_press(self, event):
'on button press we will see if the mouse is over us and store some data'
if event.inaxes != self.ell.axes:
return
if LabeledRectangle.lock is not None:
return
contains, attrd = self.ell.contains(event)
if not contains: return
x0, y0 = self.ell.get_center()
self.press = x0, y0, event.xdata, event.ydata
LabeledEllipse.lock = self
self.text.set_visible(True)
self.text.draw()

def on_release(self, event):
'on release we reset the press data'
if LabeledEllipse.lock is not self:
return
self.press = None
LabeledEllipse.lock = None
self.text.set_visible(False)
self.text.draw()

def disconnect(self):
'disconnect all the stored connection ids'
self.ell.figure.canvas.mpl_disconnect(self.cidpress)
self.ell.figure.canvas.mpl_disconnect(self.cidrelease)

0 comments on commit 1f02ae8

Please sign in to comment.