From 5ceeed9a96b13a2abc734ec6bd21eacd91e06c9c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 31 Jul 2020 09:50:53 +0200 Subject: [PATCH] Show electrode label on electrode widget upon clicking --- .../electrodegeometrywidget.py | 61 ++++++++---- .../widgets/mapswidget/activitymapwidget.py | 12 ++- .../widgets/mapswidget/templatemapswidget.py | 2 +- spikewidgets/widgets/mapswidget/utils.py | 47 --------- spikewidgets/widgets/utils.py | 95 +++++++++++++++++++ 5 files changed, 143 insertions(+), 74 deletions(-) delete mode 100644 spikewidgets/widgets/mapswidget/utils.py create mode 100644 spikewidgets/widgets/utils.py diff --git a/spikewidgets/widgets/electrodegeometrywidget/electrodegeometrywidget.py b/spikewidgets/widgets/electrodegeometrywidget/electrodegeometrywidget.py index 94e01be..e881814 100644 --- a/spikewidgets/widgets/electrodegeometrywidget/electrodegeometrywidget.py +++ b/spikewidgets/widgets/electrodegeometrywidget/electrodegeometrywidget.py @@ -1,8 +1,10 @@ import numpy as np +from matplotlib.patches import Ellipse +from ..utils import LabeledEllipse from spikewidgets.widgets.basewidget import BaseWidget -def plot_electrode_geometry(recording, markersize=20, marker='o', figure=None, ax=None): +def plot_electrode_geometry(recording, color='C0', label_color='r', figure=None, ax=None): """ Plots electrode geometry. @@ -10,10 +12,10 @@ def plot_electrode_geometry(recording, markersize=20, marker='o', figure=None, a ---------- recording: RecordingExtractor The recordng extractor object - markersize: int - The size of the marker for the electrodes - marker: str - The matplotlib marker to use (default 'o') + color: matplotlib color + The color of the electrodes + label_color: matplotlib color + The color of the channel label when clicking figure: matplotlib figure The figure to be used. If not given a figure is created ax: matplotlib axis @@ -26,8 +28,8 @@ def plot_electrode_geometry(recording, markersize=20, marker='o', figure=None, a """ W = ElectrodeGeometryWidget( recording=recording, - markersize=markersize, - marker=marker, + color=color, + label_color=label_color, figure=figure, ax=ax ) @@ -36,34 +38,51 @@ def plot_electrode_geometry(recording, markersize=20, marker='o', figure=None, a class ElectrodeGeometryWidget(BaseWidget): - def __init__(self, *, recording, markersize=10, marker='o', figure=None, ax=None): + def __init__(self, *, recording, color='C0', label_color='r', figure=None, ax=None): if 'location' not in recording.get_shared_channel_property_names(): raise AttributeError("'location' not found as a property") BaseWidget.__init__(self, figure, ax) self._recording = recording - self._ms = markersize - self._mark = marker + self._color = color + self._label_color = label_color self.name = 'ElectrodeGeometry' def plot(self, width=4, height=4): self._do_plot(width=width, height=height) def _do_plot(self, width, height): - geom = np.array(self._recording.get_channel_locations()) + locations = np.array(self._recording.get_channel_locations()) self.ax.axis('off') - x = geom[:, 0] - y = geom[:, 1] - xmin = np.min(x) - xmax = np.max(x) - ymin = np.min(y) - ymax = np.max(y) + x = locations[:, 0] + y = locations[:, 1] + x_un = np.unique(x) + y_un = np.unique(y) - margin = np.maximum(xmax - xmin, ymax - ymin) * 0.2 + if len(y_un) == 1: + pitch_x = np.min(np.diff(x_un)) + pitch_y = pitch_x + elif len(x_un) == 1: + pitch_y = np.min(np.diff(y_un)) + pitch_x = pitch_y + else: + pitch_x = np.min(np.diff(x_un)) + pitch_y = np.min(np.diff(y_un)) + + self._drs = [] + elec_x = 0.9 * pitch_x + elec_y = 0.9 * pitch_y + for (loc, ch) in zip(locations, self._recording.get_channel_ids()): + ell = Ellipse((loc[0] - elec_x / 2, loc[1] - elec_y / 2), elec_x, elec_y, + color=self._color, alpha=0.9) + self.ax.add_patch(ell) + dr = LabeledEllipse(ell, ch, self._label_color) + dr.connect() + self._drs.append(dr) - self.ax.scatter(x, y, marker=self._mark, s=int(self._ms)) self.ax.axis('equal') self.ax.set_xticks([]) self.ax.set_yticks([]) - self.ax.set_xlim(xmin - margin, xmax + margin) - self.ax.set_ylim(ymin - margin, ymax + margin) + self.ax.set_xlim(np.min(x) - pitch_x, np.max(x) + pitch_x) + self.ax.set_ylim(np.min(y) - pitch_y, np.max(y) + pitch_y) + diff --git a/spikewidgets/widgets/mapswidget/activitymapwidget.py b/spikewidgets/widgets/mapswidget/activitymapwidget.py index 7c50b45..0814480 100644 --- a/spikewidgets/widgets/mapswidget/activitymapwidget.py +++ b/spikewidgets/widgets/mapswidget/activitymapwidget.py @@ -1,11 +1,11 @@ import numpy as np import spiketoolkit as st import matplotlib.pylab as plt -from .utils import LabeledRectangle +from ..utils import LabeledRectangle from spikewidgets.widgets.basewidget import BaseWidget -def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', background='on', label_color='r', +def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', background='on', label_color='r', transpose=False, frame=False, ax=None, figure=None): """ Plots spike rate (estimated using simple threshold detector) as 2D activity map. @@ -52,7 +52,8 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', class ActivityMapWidget(BaseWidget): - def __init__(self, recording, channel_ids, trange, cmap, background, label_color='r', transpose=False, frame=False, figure=None, ax=None): + def __init__(self, recording, channel_ids, trange, cmap, background, label_color='r', transpose=False, frame=False, + figure=None, ax=None): BaseWidget.__init__(self, figure, ax) self._recording = recording self._channel_ids = channel_ids @@ -81,7 +82,7 @@ def _do_plot(self): start_frame=self._trange[0], end_frame=self._trange[1]) if self._transpose: - locations = np.roll(locations,1,axis=1) + locations = np.roll(locations, 1, axis=1) x = locations[:, 0] y = locations[:, 1] @@ -121,7 +122,8 @@ def _do_plot(self): self.ax.set_xlim(np.min(x) - pitch_x, np.max(x) + pitch_x) self.ax.set_ylim(np.min(y) - pitch_y, np.max(y) + pitch_y) if self._frame: - rect = plt.Rectangle((np.min(x) - pitch_x, np.min(y) - pitch_y), np.max(x)-np.min(x) + 2*pitch_x, np.max(y) - np.min(y) + 2*pitch_y, fill=None, edgecolor='k') + rect = plt.Rectangle((np.min(x) - pitch_x, np.min(y) - pitch_y), np.max(x) - np.min(x) + 2 * pitch_x, + np.max(y) - np.min(y) + 2 * pitch_y, fill=None, edgecolor='k') self.ax.add_patch(rect) self.ax.axis('equal') self.ax.axis('off') diff --git a/spikewidgets/widgets/mapswidget/templatemapswidget.py b/spikewidgets/widgets/mapswidget/templatemapswidget.py index 47fa8c3..b479cf8 100644 --- a/spikewidgets/widgets/mapswidget/templatemapswidget.py +++ b/spikewidgets/widgets/mapswidget/templatemapswidget.py @@ -1,7 +1,7 @@ import numpy as np import spiketoolkit as st import matplotlib.pylab as plt -from .utils import LabeledRectangle +from ..utils import LabeledRectangle from spikewidgets.widgets.basewidget import BaseMultiWidget diff --git a/spikewidgets/widgets/mapswidget/utils.py b/spikewidgets/widgets/mapswidget/utils.py deleted file mode 100644 index 6d2e291..0000000 --- a/spikewidgets/widgets/mapswidget/utils.py +++ /dev/null @@ -1,47 +0,0 @@ - -class LabeledRectangle: - lock = None # only one can be animated at a time - - def __init__(self, rect, channel, color): - self.rect = rect - self.press = None - self.background = None - self.channel_str = str(channel) - axes = self.rect.axes - x0, y0 = self.rect.xy - self.text = axes.text(x0, y0, self.channel_str, color=color) - self.text.set_visible(False) - - def connect(self): - 'connect to all the events we need' - self.cidpress = self.rect.figure.canvas.mpl_connect('button_press_event', self.on_press) - self.cidrelease = self.rect.figure.canvas.mpl_connect('button_release_event', self.on_release) - - def on_press(self, event): - 'on button press we will see if the mouse is over us and store some data' - if event.inaxes != self.rect.axes: - return - if LabeledRectangle.lock is not None: - return - contains, attrd = self.rect.contains(event) - if not contains: return - x0, y0 = self.rect.xy - self.press = x0, y0, event.xdata, event.ydata - LabeledRectangle.lock = self - self.text.set_visible(True) - self.text.draw() - - def on_release(self, event): - 'on release we reset the press data' - if LabeledRectangle.lock is not self: - return - self.press = None - LabeledRectangle.lock = None - self.text.set_visible(False) - self.text.draw() - - def disconnect(self): - 'disconnect all the stored connection ids' - self.rect.figure.canvas.mpl_disconnect(self.cidpress) - self.rect.figure.canvas.mpl_disconnect(self.cidrelease) - diff --git a/spikewidgets/widgets/utils.py b/spikewidgets/widgets/utils.py new file mode 100644 index 0000000..35e8756 --- /dev/null +++ b/spikewidgets/widgets/utils.py @@ -0,0 +1,95 @@ + + + +class LabeledRectangle: + lock = None # only one can be animated at a time + + def __init__(self, rect, channel, color): + self.rect = rect + self.press = None + self.background = None + self.channel_str = str(channel) + axes = self.rect.axes + x0, y0 = self.rect.xy + self.text = axes.text(x0, y0, self.channel_str, color=color) + self.text.set_visible(False) + + def connect(self): + 'connect to all the events we need' + self.cidpress = self.rect.figure.canvas.mpl_connect('button_press_event', self.on_press) + self.cidrelease = self.rect.figure.canvas.mpl_connect('button_release_event', self.on_release) + + def on_press(self, event): + 'on button press we will see if the mouse is over us and store some data' + if event.inaxes != self.rect.axes: + return + if LabeledRectangle.lock is not None: + return + contains, attrd = self.rect.contains(event) + if not contains: return + x0, y0 = self.rect.xy + self.press = x0, y0, event.xdata, event.ydata + LabeledRectangle.lock = self + self.text.set_visible(True) + self.text.draw() + + def on_release(self, event): + 'on release we reset the press data' + if LabeledRectangle.lock is not self: + return + self.press = None + LabeledRectangle.lock = None + self.text.set_visible(False) + self.text.draw() + + def disconnect(self): + 'disconnect all the stored connection ids' + self.rect.figure.canvas.mpl_disconnect(self.cidpress) + self.rect.figure.canvas.mpl_disconnect(self.cidrelease) + + +class LabeledEllipse: + lock = None # only one can be animated at a time + + def __init__(self, ell, channel, color): + self.ell = ell + self.press = None + self.background = None + self.channel_str = str(channel) + axes = self.ell.axes + x0, y0 = self.ell.get_center() + self.text = axes.text(x0, y0, self.channel_str, color=color) + self.text.set_visible(False) + + def connect(self): + 'connect to all the events we need' + self.cidpress = self.ell.figure.canvas.mpl_connect('button_press_event', self.on_press) + self.cidrelease = self.ell.figure.canvas.mpl_connect('button_release_event', self.on_release) + + def on_press(self, event): + 'on button press we will see if the mouse is over us and store some data' + if event.inaxes != self.ell.axes: + return + if LabeledRectangle.lock is not None: + return + contains, attrd = self.ell.contains(event) + if not contains: return + x0, y0 = self.ell.get_center() + self.press = x0, y0, event.xdata, event.ydata + LabeledEllipse.lock = self + self.text.set_visible(True) + self.text.draw() + + def on_release(self, event): + 'on release we reset the press data' + if LabeledEllipse.lock is not self: + return + self.press = None + LabeledEllipse.lock = None + self.text.set_visible(False) + self.text.draw() + + def disconnect(self): + 'disconnect all the stored connection ids' + self.ell.figure.canvas.mpl_disconnect(self.cidpress) + self.ell.figure.canvas.mpl_disconnect(self.cidrelease)