Skip to content
This repository has been archived by the owner on Jun 6, 2023. It is now read-only.

Commit

Permalink
Add spiketoolkit kwargs and don't test memmap on windows
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Nov 27, 2020
1 parent 0c4ec42 commit 26988e7
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 95 deletions.
30 changes: 20 additions & 10 deletions spikewidgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
import spikecomparison as sc
import matplotlib.pyplot as plt
import unittest
import sys

if sys.platform == "win32":
memmaps = [False]
else:
memmaps = [False, True]


class TestWidgets(unittest.TestCase):
Expand All @@ -26,17 +32,20 @@ def test_geometry(self):
sw.plot_electrode_geometry(self._RX)

def test_unitwaveforms(self):
sw.plot_unit_waveforms(self._RX, self._SX)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_unit_waveforms(self._RX, self._SX, axes=axes)
for m in memmaps:
sw.plot_unit_waveforms(self._RX, self._SX, memmap=m)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_unit_waveforms(self._RX, self._SX, axes=axes, memmap=m)

def test_unittemplates(self):
sw.plot_unit_templates(self._RX, self._SX)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_unit_templates(self._RX, self._SX, axes=axes)
for m in memmaps:
sw.plot_unit_templates(self._RX, self._SX, memmap=m)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_unit_templates(self._RX, self._SX, axes=axes, memmap=m)

def test_unittemplatemaps(self):
sw.plot_unit_template_maps(self._RX, self._SX)
for m in memmaps:
sw.plot_unit_template_maps(self._RX, self._SX, memmap=m)

def test_ampdist(self):
sw.plot_amplitudes_distribution(self._RX, self._SX)
Expand All @@ -49,9 +58,10 @@ def test_amptime(self):
sw.plot_amplitudes_timeseries(self._RX, self._SX, axes=axes)

def test_features(self):
sw.plot_pca_features(self._RX, self._SX)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_pca_features(self._RX, self._SX, axes=axes)
for m in memmaps:
sw.plot_pca_features(self._RX, self._SX, memap=m)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_pca_features(self._RX, self._SX, axes=axes, memap=m)

def test_ach(self):
sw.plot_autocorrelograms(self._SX, bin_size=1, window=10)
Expand Down
25 changes: 10 additions & 15 deletions spikewidgets/widgets/featurewidgets/pcawidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def plot_pca_features(recording, sorting, unit_ids=None, max_spikes_per_unit=100, nproj=4, colormap=None,
figure=None, ax=None, axes=None):
figure=None, ax=None, axes=None, **pca_kwargs):
"""
Plots unit PCA features on best projections.
Expand All @@ -30,6 +30,7 @@ def plot_pca_features(recording, sorting, unit_ids=None, max_spikes_per_unit=100
axes: list of matplotlib axes
The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax
and figure parameters are ignored
pca_kwargs: keyword arguments for st.postprocessing.compute_unit_pca_scores()
Returns
Expand All @@ -46,35 +47,30 @@ def plot_pca_features(recording, sorting, unit_ids=None, max_spikes_per_unit=100
colormap=colormap,
figure=figure,
ax=ax,
axes=axes
axes=axes,
**pca_kwargs
)
W.plot()
return W


class PCAWidget(BaseMultiWidget):
def __init__(self, *, recording, sorting, unit_ids=None, max_spikes_per_unit=100, nproj=4, colormap=None,
figure=None, ax=None, axes=None, save_as_features=False, save_waveforms_as_features=False):
def __init__(self, *, recording, sorting, unit_ids=None, nproj=4, colormap=None,
figure=None, ax=None, axes=None, **pca_kwargs):
BaseMultiWidget.__init__(self, figure, ax, axes)
self._sorting = sorting
self._recording = recording
self._unit_ids = unit_ids
self._nproj = nproj
self._max_spikes_per_unit = max_spikes_per_unit
self._pca_scores = None
self._nproj = nproj
self._colormap = colormap
self._save_as_features = save_as_features
self._save_waveforms_as_features = save_waveforms_as_features
self._pca_kwargs = pca_kwargs
self.name = 'Feature'

def _compute_pca(self):
self._pca_scores = st.postprocessing.compute_unit_pca_scores(recording=self._recording,
sorting=self._sorting,
by_electrode=True,
max_spikes_per_unit=self._max_spikes_per_unit,
save_as_features=self._save_as_features,
save_waveforms_as_features=
self._save_waveforms_as_features)
**self._pca_kwargs)

def plot(self):
self._do_plot()
Expand Down Expand Up @@ -106,8 +102,7 @@ def _do_plot(self):
distances.append(dist)
proj.append([ch1, pc1, ch2, pc2])

list_best_proj = np.array(
proj)[np.argsort(distances)[::-1][:self._nproj]]
list_best_proj = np.array(proj)[np.argsort(distances)[::-1][:self._nproj]]
self._plot_proj_multi(list_best_proj)

def compute_cluster_average_distance(self, pc1, ch1, pc2, ch2):
Expand Down
10 changes: 7 additions & 3 deletions spikewidgets/widgets/mapswidget/activitymapwidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', background='on', label_color='r',
transpose=False, frame=False, ax=None, figure=None):
transpose=False, frame=False, ax=None, figure=None, **activity_kwargs):
"""
Plots spike rate (estimated using simple threshold detector) as 2D activity map.
Expand All @@ -28,6 +28,7 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis',
The figure to be used. If not given a figure is created
ax: matplotlib axis
The axis to be used. If not given an axis is created
activity_kwargs: keyword arguments for st.postprocessing.compute_channel_spiking_activity()
Returns
-------
Expand All @@ -45,6 +46,7 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis',
frame=frame,
figure=figure,
ax=ax,
**activity_kwargs
)
W.plot()
return W
Expand All @@ -53,7 +55,7 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis',
class ActivityMapWidget(BaseWidget):

def __init__(self, recording, channel_ids, trange, cmap, background, label_color='r', transpose=False, frame=False,
figure=None, ax=None):
figure=None, ax=None, **activity_kwargs):
BaseWidget.__init__(self, figure, ax)
self._recording = recording
self._channel_ids = channel_ids
Expand All @@ -63,6 +65,7 @@ def __init__(self, recording, channel_ids, trange, cmap, background, label_color
self._frame = frame
self._bg = background
self._label_color = label_color
self._activity_kwargs = activity_kwargs
self.name = 'ActivityMap'
assert 'location' in self._recording.get_shared_channel_property_names(), "Activity map requires 'location'" \
"property"
Expand All @@ -80,7 +83,8 @@ def _do_plot(self):
locations = self._recording.get_channel_locations(channel_ids=self._channel_ids)
activity = st.postprocessing.compute_channel_spiking_activity(self._recording,
start_frame=self._trange[0],
end_frame=self._trange[1])
end_frame=self._trange[1],
**self._activity_kwargs)
if self._transpose:
locations = np.roll(locations, 1, axis=1)

Expand Down
29 changes: 10 additions & 19 deletions spikewidgets/widgets/mapswidget/templatemapswidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None, peak='neg', log=False, ncols=10,
ms_before=1., ms_after=2., max_spikes_per_unit=100, background='on', cmap='viridis',
label_color='r', figure=None, ax=None, axes=None):
background='on', cmap='viridis', label_color='r', figure=None, ax=None, axes=None,
**templates_kwargs):
"""
Plots sorting comparison confusion matrix.
Expand All @@ -27,12 +27,6 @@ def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None,
If True, log scale is used
ncols: int
Number of columns if multiple units are displayed
ms_before: float
Time before peak (ms)
ms_after: float
Time after peak (ms)
max_spikes_per_unit: int
Maximum number of spikes to display per unit.
background: str
'on' or 'off'
cmap: matplotlib colormap
Expand All @@ -46,6 +40,8 @@ def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None,
axes: list of matplotlib axes
The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax
and figure parameters are ignored
templates_kwargs: keyword arguments for st.postprocessing.get_unit_templates()
Returns
-------
Expand All @@ -57,26 +53,24 @@ def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None,
sorting=sorting,
channel_ids=channel_ids,
unit_ids=unit_ids,
max_spikes_per_unit=max_spikes_per_unit,
peak=peak,
log=log,
ncols=ncols,
ms_before=ms_before,
ms_after=ms_after,
background=background,
cmap=cmap,
label_color=label_color,
figure=figure,
ax=ax,
axes=axes
axes=axes,
**templates_kwargs
)
W.plot()
return W


class UnitTemplateMapsWidget(BaseMultiWidget):
def __init__(self, recording, sorting, channel_ids, unit_ids, peak, log, ncols, max_spikes_per_unit, ms_before,
ms_after, background, cmap, label_color='r', figure=None, ax=None, axes=None):
def __init__(self, recording, sorting, channel_ids, unit_ids, peak, log, ncols, background, cmap, label_color='r',
figure=None, ax=None, axes=None, **template_kwargs):
BaseMultiWidget.__init__(self, figure, ax, axes)
self._recording = recording
self._sorting = sorting
Expand All @@ -85,12 +79,10 @@ def __init__(self, recording, sorting, channel_ids, unit_ids, peak, log, ncols,
self._peak = peak
self._log = log
self._ncols = ncols
self._ms_before = ms_before
self._ms_after = ms_after
self._max_spikes_per_unit = max_spikes_per_unit
self._bg = background
self._cmap = cmap
self._label_color = label_color
self._template_kwargs = template_kwargs
self.name = 'UnitTemplateMaps'
assert 'location' in self._recording.get_shared_channel_property_names(), "Activity map requires 'location'" \
"property"
Expand All @@ -102,8 +94,7 @@ def _do_plot(self):
locations = self._recording.get_channel_locations(channel_ids=self._channel_ids)
templates = st.postprocessing.get_unit_templates(self._recording, self._sorting, channel_ids=self._channel_ids,
unit_ids=self._unit_ids,
max_spikes_per_unit=self._max_spikes_per_unit,
ms_before=self._ms_before, ms_after=self._ms_after)
**self._template_kwargs)
if self._channel_ids is None:
channel_ids = self._recording.get_channel_ids()
else:
Expand Down
Loading

0 comments on commit 26988e7

Please sign in to comment.