From 4199b9c8479aa9d11dab813499f15a5543d04376 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 18 Nov 2020 12:51:21 +0100 Subject: [PATCH 1/3] Modified plot_activity_map with new spiketoolkit changes --- spikewidgets/tests/test_widgets.py | 4 +++ .../widgets/mapswidget/activitymapwidget.py | 35 ++++++++++++------- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/spikewidgets/tests/test_widgets.py b/spikewidgets/tests/test_widgets.py index d13db58..fe8be29 100644 --- a/spikewidgets/tests/test_widgets.py +++ b/spikewidgets/tests/test_widgets.py @@ -25,6 +25,10 @@ def test_spectrogram(self): def test_geometry(self): sw.plot_electrode_geometry(self._RX) + def test_activitymap(self): + sw.plot_activity_map(self._RX, activity='rate') + sw.plot_activity_map(self._RX, activity='amplitude') + def test_unitwaveforms(self): sw.plot_unit_waveforms(self._RX, self._SX) fig, axes = plt.subplots(self.num_units, 1) diff --git a/spikewidgets/widgets/mapswidget/activitymapwidget.py b/spikewidgets/widgets/mapswidget/activitymapwidget.py index 81a2355..00dcaf1 100644 --- a/spikewidgets/widgets/mapswidget/activitymapwidget.py +++ b/spikewidgets/widgets/mapswidget/activitymapwidget.py @@ -2,13 +2,13 @@ import spiketoolkit as st import matplotlib.pylab as plt import matplotlib as mpl -from mpl_toolkits.axes_grid1 import make_axes_locatable from ..utils import LabeledRectangle from spikewidgets.widgets.basewidget import BaseWidget from mpl_toolkits.axes_grid1.inset_locator import inset_axes -def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', background='on', label_color='r', +def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate', + cmap='viridis', background='on', label_color='r', transpose=False, frame=False, colorbar=False, colorbar_bbox=None, colorbar_orientation='vertical', colorbar_width=0.02, recompute_info=False, ax=None, figure=None): @@ -23,6 +23,8 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', The channel ids to display. trange: list List with start time and end time + activity: str + 'rate' or 'amplitude'. If 'rate' the channel spike rate is used. If 'amplitude' the spike amplitude is used. cmap: matplotlib colormap The colormap to be used (default 'viridis') background: bool @@ -55,6 +57,7 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', recording=recording, channel_ids=channel_ids, trange=trange, + activity=activity, background=background, cmap=cmap, label_color=label_color, @@ -73,12 +76,13 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', class ActivityMapWidget(BaseWidget): - def __init__(self, recording, channel_ids, trange, cmap, background, label_color='r', transpose=False, frame=False, - colorbar=False, colorbar_bbox=None, colorbar_orientation='vertical', colorbar_width=0.02, - recompute_info=False, figure=None, ax=None): + def __init__(self, recording, channel_ids, activity, trange, cmap, background, label_color='r', + transpose=False, frame=False, colorbar=False, colorbar_bbox=None, colorbar_orientation='vertical', + colorbar_width=0.02, recompute_info=False, figure=None, ax=None): BaseWidget.__init__(self, figure, ax) self._recording = recording self._channel_ids = channel_ids + self._activity = activity self._trange = trange self._transpose = transpose self._cmap = cmap @@ -92,6 +96,7 @@ def __init__(self, recording, channel_ids, trange, cmap, background, label_color self._recompute_info = recompute_info self.colorbar = None self.name = 'ActivityMap' + assert activity in ['rate', 'amplitude'], "'activity' can be either 'rate' or 'amplitude'" assert 'location' in self._recording.get_shared_channel_property_names(), "Activity map requires 'location'" \ "property" @@ -106,13 +111,14 @@ def _do_plot(self): self._trange = [int(t * self._recording.get_sampling_frequency()) for t in self._trange] locations = self._recording.get_channel_locations(channel_ids=self._channel_ids) - activity = st.postprocessing.compute_channel_spiking_activity(self._recording, - start_frame=self._trange[0], - end_frame=self._trange[1], - method='detection', - align=False, - recompute_info=self._recompute_info, - verbose=False) + spike_rates, spike_amplitudes = st.postprocessing.compute_channel_spiking_activity(self._recording, + start_frame=self._trange[0], + end_frame=self._trange[1], + method='detection', + align=False, + recompute_info= + self._recompute_info, + verbose=False) if self._transpose: locations = np.roll(locations, 1, axis=1) @@ -143,6 +149,11 @@ def _do_plot(self): elec_x = 0.9 * pitch_x elec_y = 0.9 * pitch_y + if self._activity == 'rate': + activity = spike_rates + else: # amplitude + activity = spike_amplitudes + max_activity = np.round(np.max(activity), 2) # normalize From 5b23681a4580a6a1258519f41db3d4f40d687393 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 18 Nov 2020 14:13:05 +0100 Subject: [PATCH 2/3] Added log option and fix spike amplitudes map --- .../widgets/mapswidget/activitymapwidget.py | 34 +++++++++++++------ .../widgets/mapswidget/templatemapswidget.py | 3 ++ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/spikewidgets/widgets/mapswidget/activitymapwidget.py b/spikewidgets/widgets/mapswidget/activitymapwidget.py index 00dcaf1..87c3f2d 100644 --- a/spikewidgets/widgets/mapswidget/activitymapwidget.py +++ b/spikewidgets/widgets/mapswidget/activitymapwidget.py @@ -7,7 +7,7 @@ from mpl_toolkits.axes_grid1.inset_locator import inset_axes -def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate', +def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate', log=False, cmap='viridis', background='on', label_color='r', transpose=False, frame=False, colorbar=False, colorbar_bbox=None, colorbar_orientation='vertical', colorbar_width=0.02, recompute_info=False, @@ -20,19 +20,21 @@ def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate', recording: RecordingExtractor The recordng extractor object channel_ids: list - The channel ids to display. + The channel ids to display trange: list List with start time and end time activity: str - 'rate' or 'amplitude'. If 'rate' the channel spike rate is used. If 'amplitude' the spike amplitude is used. + 'rate' or 'amplitude'. If 'rate' the channel spike rate is used. If 'amplitude' the spike amplitude is used + log: bool + If True, log scale is used cmap: matplotlib colormap The colormap to be used (default 'viridis') background: bool If True, a background is added in between electrodes transpose: bool, optional, default: False - Swap x and y channel coordinates if True. + Swap x and y channel coordinates if True frame: bool, optional, default: False - Draw a frame around the array if True. + Draw a frame around the array if True colorbar: bool If True, a colorbar is displayed colorbar_bbox: bbox @@ -58,6 +60,7 @@ def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate', channel_ids=channel_ids, trange=trange, activity=activity, + log=log, background=background, cmap=cmap, label_color=label_color, @@ -76,7 +79,7 @@ def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate', class ActivityMapWidget(BaseWidget): - def __init__(self, recording, channel_ids, activity, trange, cmap, background, label_color='r', + def __init__(self, recording, channel_ids, activity, log, trange, cmap, background, label_color='r', transpose=False, frame=False, colorbar=False, colorbar_bbox=None, colorbar_orientation='vertical', colorbar_width=0.02, recompute_info=False, figure=None, ax=None): BaseWidget.__init__(self, figure, ax) @@ -84,6 +87,7 @@ def __init__(self, recording, channel_ids, activity, trange, cmap, background, l self._channel_ids = channel_ids self._activity = activity self._trange = trange + self._log = log self._transpose = transpose self._cmap = cmap self._frame = frame @@ -152,12 +156,18 @@ def _do_plot(self): if self._activity == 'rate': activity = spike_rates else: # amplitude - activity = spike_amplitudes + activity = np.abs(spike_amplitudes) max_activity = np.round(np.max(activity), 2) + min_activity = np.round(np.min(activity), 2) + + if self._log: + if np.any(activity < 1): + activity += (1 - np.min(activity)) + activity = np.log(activity) # normalize - activity -= np.min(activity) + activity -= (np.min(activity) + 1e-5) activity /= np.ptp(activity) for (loc, act, ch) in zip(locations, activity, self._recording.get_channel_ids()): @@ -209,13 +219,17 @@ def _do_plot(self): cax = inset_axes(self.ax, width="100%", height="100%", bbox_to_anchor=bbox, bbox_transform=self.ax.transData) - scalable = mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=0, vmax=1), cmap=self._cmap) + if self._log: + norm = mpl.colors.LogNorm(vmin=1e-5, vmax=1) + else: + norm = mpl.colors.Normalize(vmin=0, vmax=1) + scalable = mpl.cm.ScalarMappable(norm=norm, cmap=self._cmap) self.colorbar = self.figure.colorbar(scalable, cax=cax, orientation=self._colorbar_orient)#, shrink=0.5) cax.yaxis.set_ticks_position('left') cax.yaxis.set_label_position('left') self.colorbar.set_ticks((0, 1)) - self.colorbar.set_ticklabels((0, max_activity)) + self.colorbar.set_ticklabels((min_activity, max_activity)) if self._colorbar_orient == 'vertical': rotation = 90 else: diff --git a/spikewidgets/widgets/mapswidget/templatemapswidget.py b/spikewidgets/widgets/mapswidget/templatemapswidget.py index b479cf8..91f1430 100644 --- a/spikewidgets/widgets/mapswidget/templatemapswidget.py +++ b/spikewidgets/widgets/mapswidget/templatemapswidget.py @@ -150,7 +150,10 @@ def _do_plot(self): for i, (template, unit) in enumerate(zip(templates, unit_ids)): ax = self.get_tiled_ax(i, nrows, ncols) temp_map = np.abs(fun(template, axis=1)) + if self._log: + if np.any(temp_map < 1): + temp_map += (1 - np.min(temp_map)) temp_map = np.log(temp_map) # normalize From cc8bc7d0a7cb81abeab4bd7e306e6cfe6f0cdaad Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sun, 22 Nov 2020 17:06:49 +0100 Subject: [PATCH 3/3] Colorbar label fix --- spikewidgets/widgets/mapswidget/activitymapwidget.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/spikewidgets/widgets/mapswidget/activitymapwidget.py b/spikewidgets/widgets/mapswidget/activitymapwidget.py index 87c3f2d..71f4da1 100644 --- a/spikewidgets/widgets/mapswidget/activitymapwidget.py +++ b/spikewidgets/widgets/mapswidget/activitymapwidget.py @@ -234,6 +234,7 @@ def _do_plot(self): rotation = 90 else: rotation = 0 - self.colorbar.set_label('Sp/s', rotation=rotation) - - + if self._activity == 'rate': + self.colorbar.set_label('Sp/s', rotation=rotation) + else: + self.colorbar.set_label('Amp.', rotation=rotation)