Skip to content
This repository has been archived by the owner on Jun 6, 2023. It is now read-only.

Commit

Permalink
Merge pull request #70 from SpikeInterface/plot_templates_fix
Browse files Browse the repository at this point in the history
Plot templates fix
  • Loading branch information
alejoe91 authored Nov 27, 2020
2 parents 0c27a8c + 6339d6d commit 2abfdac
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 130 deletions.
30 changes: 20 additions & 10 deletions spikewidgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
import spikecomparison as sc
import matplotlib.pyplot as plt
import unittest
import sys

if sys.platform == "win32":
memmaps = [False]
else:
memmaps = [False, True]


class TestWidgets(unittest.TestCase):
Expand Down Expand Up @@ -30,17 +36,20 @@ def test_activitymap(self):
sw.plot_activity_map(self._RX, activity='amplitude')

def test_unitwaveforms(self):
sw.plot_unit_waveforms(self._RX, self._SX)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_unit_waveforms(self._RX, self._SX, axes=axes)
for m in memmaps:
sw.plot_unit_waveforms(self._RX, self._SX, memmap=m)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_unit_waveforms(self._RX, self._SX, axes=axes, memmap=m)

def test_unittemplates(self):
sw.plot_unit_templates(self._RX, self._SX)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_unit_templates(self._RX, self._SX, axes=axes)
for m in memmaps:
sw.plot_unit_templates(self._RX, self._SX, memmap=m)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_unit_templates(self._RX, self._SX, axes=axes, memmap=m)

def test_unittemplatemaps(self):
sw.plot_unit_template_maps(self._RX, self._SX)
for m in memmaps:
sw.plot_unit_template_maps(self._RX, self._SX, memmap=m)

def test_ampdist(self):
sw.plot_amplitudes_distribution(self._RX, self._SX)
Expand All @@ -53,9 +62,10 @@ def test_amptime(self):
sw.plot_amplitudes_timeseries(self._RX, self._SX, axes=axes)

def test_features(self):
sw.plot_pca_features(self._RX, self._SX)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_pca_features(self._RX, self._SX, axes=axes)
for m in memmaps:
sw.plot_pca_features(self._RX, self._SX, memap=m)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_pca_features(self._RX, self._SX, axes=axes, memap=m)

def test_ach(self):
sw.plot_autocorrelograms(self._SX, bin_size=1, window=10)
Expand Down
25 changes: 10 additions & 15 deletions spikewidgets/widgets/featurewidgets/pcawidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def plot_pca_features(recording, sorting, unit_ids=None, max_spikes_per_unit=100, nproj=4, colormap=None,
figure=None, ax=None, axes=None):
figure=None, ax=None, axes=None, **pca_kwargs):
"""
Plots unit PCA features on best projections.
Expand All @@ -30,6 +30,7 @@ def plot_pca_features(recording, sorting, unit_ids=None, max_spikes_per_unit=100
axes: list of matplotlib axes
The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax
and figure parameters are ignored
pca_kwargs: keyword arguments for st.postprocessing.compute_unit_pca_scores()
Returns
Expand All @@ -46,35 +47,30 @@ def plot_pca_features(recording, sorting, unit_ids=None, max_spikes_per_unit=100
colormap=colormap,
figure=figure,
ax=ax,
axes=axes
axes=axes,
**pca_kwargs
)
W.plot()
return W


class PCAWidget(BaseMultiWidget):
def __init__(self, *, recording, sorting, unit_ids=None, max_spikes_per_unit=100, nproj=4, colormap=None,
figure=None, ax=None, axes=None, save_as_features=False, save_waveforms_as_features=False):
def __init__(self, *, recording, sorting, unit_ids=None, nproj=4, colormap=None,
figure=None, ax=None, axes=None, **pca_kwargs):
BaseMultiWidget.__init__(self, figure, ax, axes)
self._sorting = sorting
self._recording = recording
self._unit_ids = unit_ids
self._nproj = nproj
self._max_spikes_per_unit = max_spikes_per_unit
self._pca_scores = None
self._nproj = nproj
self._colormap = colormap
self._save_as_features = save_as_features
self._save_waveforms_as_features = save_waveforms_as_features
self._pca_kwargs = pca_kwargs
self.name = 'Feature'

def _compute_pca(self):
self._pca_scores = st.postprocessing.compute_unit_pca_scores(recording=self._recording,
sorting=self._sorting,
by_electrode=True,
max_spikes_per_unit=self._max_spikes_per_unit,
save_as_features=self._save_as_features,
save_waveforms_as_features=
self._save_waveforms_as_features)
**self._pca_kwargs)

def plot(self):
self._do_plot()
Expand Down Expand Up @@ -106,8 +102,7 @@ def _do_plot(self):
distances.append(dist)
proj.append([ch1, pc1, ch2, pc2])

list_best_proj = np.array(
proj)[np.argsort(distances)[::-1][:self._nproj]]
list_best_proj = np.array(proj)[np.argsort(distances)[::-1][:self._nproj]]
self._plot_proj_multi(list_best_proj)

def compute_cluster_average_distance(self, pc1, ch1, pc2, ch2):
Expand Down
19 changes: 7 additions & 12 deletions spikewidgets/widgets/mapswidget/activitymapwidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate', log=False,
cmap='viridis', background='on', label_color='r',
transpose=False, frame=False, colorbar=False, colorbar_bbox=None,
colorbar_orientation='vertical', colorbar_width=0.02, recompute_info=False,
ax=None, figure=None):
colorbar_orientation='vertical', colorbar_width=0.02,
ax=None, figure=None, **activity_kwargs):
"""
Plots spike rate (estimated using simple threshold detector) as 2D activity map.
Expand Down Expand Up @@ -43,12 +43,11 @@ def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate',
'vertical' or 'horizontal'
colorbar_width: float
Width of colorbar in figure coordinates (default 0.02)
recompute_info: bool
If True, spike rates are recomputed
figure: matplotlib figure
The figure to be used. If not given a figure is created
ax: matplotlib axis
The axis to be used. If not given an axis is created
activity_kwargs: keyword arguments for st.postprocessing.compute_channel_spiking_activity()
Returns
-------
Expand All @@ -72,7 +71,7 @@ def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate',
colorbar_bbox=colorbar_bbox,
colorbar_orientation=colorbar_orientation,
colorbar_width=colorbar_width,
recompute_info=recompute_info
**activity_kwargs
)
W.plot()
return W
Expand All @@ -81,7 +80,7 @@ def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate',
class ActivityMapWidget(BaseWidget):
def __init__(self, recording, channel_ids, activity, log, trange, cmap, background, label_color='r',
transpose=False, frame=False, colorbar=False, colorbar_bbox=None, colorbar_orientation='vertical',
colorbar_width=0.02, recompute_info=False, figure=None, ax=None):
colorbar_width=0.02, figure=None, ax=None, **activity_kwargs):
BaseWidget.__init__(self, figure, ax)
self._recording = recording
self._channel_ids = channel_ids
Expand All @@ -93,11 +92,11 @@ def __init__(self, recording, channel_ids, activity, log, trange, cmap, backgrou
self._frame = frame
self._bg = background
self._label_color = label_color
self._activity_kwargs = activity_kwargs
self._show_colorbar = colorbar
self._colorbar_bbox = colorbar_bbox
self._colorbar_orient = colorbar_orientation
self._colorbar_width = colorbar_width
self._recompute_info = recompute_info
self.colorbar = None
self.name = 'ActivityMap'
assert activity in ['rate', 'amplitude'], "'activity' can be either 'rate' or 'amplitude'"
Expand All @@ -118,11 +117,7 @@ def _do_plot(self):
spike_rates, spike_amplitudes = st.postprocessing.compute_channel_spiking_activity(self._recording,
start_frame=self._trange[0],
end_frame=self._trange[1],
method='detection',
align=False,
recompute_info=
self._recompute_info,
verbose=False)
**self._activity_kwargs)
if self._transpose:
locations = np.roll(locations, 1, axis=1)

Expand Down
29 changes: 10 additions & 19 deletions spikewidgets/widgets/mapswidget/templatemapswidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None, peak='neg', log=False, ncols=10,
ms_before=1., ms_after=2., max_spikes_per_unit=100, background='on', cmap='viridis',
label_color='r', figure=None, ax=None, axes=None):
background='on', cmap='viridis', label_color='r', figure=None, ax=None, axes=None,
**templates_kwargs):
"""
Plots sorting comparison confusion matrix.
Expand All @@ -27,12 +27,6 @@ def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None,
If True, log scale is used
ncols: int
Number of columns if multiple units are displayed
ms_before: float
Time before peak (ms)
ms_after: float
Time after peak (ms)
max_spikes_per_unit: int
Maximum number of spikes to display per unit.
background: str
'on' or 'off'
cmap: matplotlib colormap
Expand All @@ -46,6 +40,8 @@ def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None,
axes: list of matplotlib axes
The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax
and figure parameters are ignored
templates_kwargs: keyword arguments for st.postprocessing.get_unit_templates()
Returns
-------
Expand All @@ -57,26 +53,24 @@ def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None,
sorting=sorting,
channel_ids=channel_ids,
unit_ids=unit_ids,
max_spikes_per_unit=max_spikes_per_unit,
peak=peak,
log=log,
ncols=ncols,
ms_before=ms_before,
ms_after=ms_after,
background=background,
cmap=cmap,
label_color=label_color,
figure=figure,
ax=ax,
axes=axes
axes=axes,
**templates_kwargs
)
W.plot()
return W


class UnitTemplateMapsWidget(BaseMultiWidget):
def __init__(self, recording, sorting, channel_ids, unit_ids, peak, log, ncols, max_spikes_per_unit, ms_before,
ms_after, background, cmap, label_color='r', figure=None, ax=None, axes=None):
def __init__(self, recording, sorting, channel_ids, unit_ids, peak, log, ncols, background, cmap, label_color='r',
figure=None, ax=None, axes=None, **template_kwargs):
BaseMultiWidget.__init__(self, figure, ax, axes)
self._recording = recording
self._sorting = sorting
Expand All @@ -85,12 +79,10 @@ def __init__(self, recording, sorting, channel_ids, unit_ids, peak, log, ncols,
self._peak = peak
self._log = log
self._ncols = ncols
self._ms_before = ms_before
self._ms_after = ms_after
self._max_spikes_per_unit = max_spikes_per_unit
self._bg = background
self._cmap = cmap
self._label_color = label_color
self._template_kwargs = template_kwargs
self.name = 'UnitTemplateMaps'
assert 'location' in self._recording.get_shared_channel_property_names(), "Activity map requires 'location'" \
"property"
Expand All @@ -102,8 +94,7 @@ def _do_plot(self):
locations = self._recording.get_channel_locations(channel_ids=self._channel_ids)
templates = st.postprocessing.get_unit_templates(self._recording, self._sorting, channel_ids=self._channel_ids,
unit_ids=self._unit_ids,
max_spikes_per_unit=self._max_spikes_per_unit,
ms_before=self._ms_before, ms_after=self._ms_after)
**self._template_kwargs)
if self._channel_ids is None:
channel_ids = self._recording.get_channel_ids()
else:
Expand Down
Loading

0 comments on commit 2abfdac

Please sign in to comment.