diff --git a/spikewidgets/tests/test_widgets.py b/spikewidgets/tests/test_widgets.py index d13db58..9580ae7 100644 --- a/spikewidgets/tests/test_widgets.py +++ b/spikewidgets/tests/test_widgets.py @@ -3,6 +3,12 @@ import spikecomparison as sc import matplotlib.pyplot as plt import unittest +import sys + +if sys.platform == "win32": + memmaps = [False] +else: + memmaps = [False, True] class TestWidgets(unittest.TestCase): @@ -26,17 +32,20 @@ def test_geometry(self): sw.plot_electrode_geometry(self._RX) def test_unitwaveforms(self): - sw.plot_unit_waveforms(self._RX, self._SX) - fig, axes = plt.subplots(self.num_units, 1) - sw.plot_unit_waveforms(self._RX, self._SX, axes=axes) + for m in memmaps: + sw.plot_unit_waveforms(self._RX, self._SX, memmap=m) + fig, axes = plt.subplots(self.num_units, 1) + sw.plot_unit_waveforms(self._RX, self._SX, axes=axes, memmap=m) def test_unittemplates(self): - sw.plot_unit_templates(self._RX, self._SX) - fig, axes = plt.subplots(self.num_units, 1) - sw.plot_unit_templates(self._RX, self._SX, axes=axes) + for m in memmaps: + sw.plot_unit_templates(self._RX, self._SX, memmap=m) + fig, axes = plt.subplots(self.num_units, 1) + sw.plot_unit_templates(self._RX, self._SX, axes=axes, memmap=m) def test_unittemplatemaps(self): - sw.plot_unit_template_maps(self._RX, self._SX) + for m in memmaps: + sw.plot_unit_template_maps(self._RX, self._SX, memmap=m) def test_ampdist(self): sw.plot_amplitudes_distribution(self._RX, self._SX) @@ -49,9 +58,10 @@ def test_amptime(self): sw.plot_amplitudes_timeseries(self._RX, self._SX, axes=axes) def test_features(self): - sw.plot_pca_features(self._RX, self._SX) - fig, axes = plt.subplots(self.num_units, 1) - sw.plot_pca_features(self._RX, self._SX, axes=axes) + for m in memmaps: + sw.plot_pca_features(self._RX, self._SX, memap=m) + fig, axes = plt.subplots(self.num_units, 1) + sw.plot_pca_features(self._RX, self._SX, axes=axes, memap=m) def test_ach(self): sw.plot_autocorrelograms(self._SX, bin_size=1, window=10) diff --git a/spikewidgets/widgets/featurewidgets/pcawidget.py b/spikewidgets/widgets/featurewidgets/pcawidget.py index f166892..c5bb30a 100644 --- a/spikewidgets/widgets/featurewidgets/pcawidget.py +++ b/spikewidgets/widgets/featurewidgets/pcawidget.py @@ -5,7 +5,7 @@ def plot_pca_features(recording, sorting, unit_ids=None, max_spikes_per_unit=100, nproj=4, colormap=None, - figure=None, ax=None, axes=None): + figure=None, ax=None, axes=None, **pca_kwargs): """ Plots unit PCA features on best projections. @@ -30,6 +30,7 @@ def plot_pca_features(recording, sorting, unit_ids=None, max_spikes_per_unit=100 axes: list of matplotlib axes The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax and figure parameters are ignored + pca_kwargs: keyword arguments for st.postprocessing.compute_unit_pca_scores() Returns @@ -46,35 +47,30 @@ def plot_pca_features(recording, sorting, unit_ids=None, max_spikes_per_unit=100 colormap=colormap, figure=figure, ax=ax, - axes=axes + axes=axes, + **pca_kwargs ) W.plot() return W class PCAWidget(BaseMultiWidget): - def __init__(self, *, recording, sorting, unit_ids=None, max_spikes_per_unit=100, nproj=4, colormap=None, - figure=None, ax=None, axes=None, save_as_features=False, save_waveforms_as_features=False): + def __init__(self, *, recording, sorting, unit_ids=None, nproj=4, colormap=None, + figure=None, ax=None, axes=None, **pca_kwargs): BaseMultiWidget.__init__(self, figure, ax, axes) self._sorting = sorting self._recording = recording self._unit_ids = unit_ids - self._nproj = nproj - self._max_spikes_per_unit = max_spikes_per_unit self._pca_scores = None + self._nproj = nproj self._colormap = colormap - self._save_as_features = save_as_features - self._save_waveforms_as_features = save_waveforms_as_features + self._pca_kwargs = pca_kwargs self.name = 'Feature' def _compute_pca(self): self._pca_scores = st.postprocessing.compute_unit_pca_scores(recording=self._recording, sorting=self._sorting, - by_electrode=True, - max_spikes_per_unit=self._max_spikes_per_unit, - save_as_features=self._save_as_features, - save_waveforms_as_features= - self._save_waveforms_as_features) + **self._pca_kwargs) def plot(self): self._do_plot() @@ -106,8 +102,7 @@ def _do_plot(self): distances.append(dist) proj.append([ch1, pc1, ch2, pc2]) - list_best_proj = np.array( - proj)[np.argsort(distances)[::-1][:self._nproj]] + list_best_proj = np.array(proj)[np.argsort(distances)[::-1][:self._nproj]] self._plot_proj_multi(list_best_proj) def compute_cluster_average_distance(self, pc1, ch1, pc2, ch2): diff --git a/spikewidgets/widgets/mapswidget/activitymapwidget.py b/spikewidgets/widgets/mapswidget/activitymapwidget.py index 0814480..40fc255 100644 --- a/spikewidgets/widgets/mapswidget/activitymapwidget.py +++ b/spikewidgets/widgets/mapswidget/activitymapwidget.py @@ -6,7 +6,7 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', background='on', label_color='r', - transpose=False, frame=False, ax=None, figure=None): + transpose=False, frame=False, ax=None, figure=None, **activity_kwargs): """ Plots spike rate (estimated using simple threshold detector) as 2D activity map. @@ -28,6 +28,7 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', The figure to be used. If not given a figure is created ax: matplotlib axis The axis to be used. If not given an axis is created + activity_kwargs: keyword arguments for st.postprocessing.compute_channel_spiking_activity() Returns ------- @@ -45,6 +46,7 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', frame=frame, figure=figure, ax=ax, + **activity_kwargs ) W.plot() return W @@ -53,7 +55,7 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', class ActivityMapWidget(BaseWidget): def __init__(self, recording, channel_ids, trange, cmap, background, label_color='r', transpose=False, frame=False, - figure=None, ax=None): + figure=None, ax=None, **activity_kwargs): BaseWidget.__init__(self, figure, ax) self._recording = recording self._channel_ids = channel_ids @@ -63,6 +65,7 @@ def __init__(self, recording, channel_ids, trange, cmap, background, label_color self._frame = frame self._bg = background self._label_color = label_color + self._activity_kwargs = activity_kwargs self.name = 'ActivityMap' assert 'location' in self._recording.get_shared_channel_property_names(), "Activity map requires 'location'" \ "property" @@ -80,7 +83,8 @@ def _do_plot(self): locations = self._recording.get_channel_locations(channel_ids=self._channel_ids) activity = st.postprocessing.compute_channel_spiking_activity(self._recording, start_frame=self._trange[0], - end_frame=self._trange[1]) + end_frame=self._trange[1], + **self._activity_kwargs) if self._transpose: locations = np.roll(locations, 1, axis=1) diff --git a/spikewidgets/widgets/mapswidget/templatemapswidget.py b/spikewidgets/widgets/mapswidget/templatemapswidget.py index b479cf8..c4d1f5a 100644 --- a/spikewidgets/widgets/mapswidget/templatemapswidget.py +++ b/spikewidgets/widgets/mapswidget/templatemapswidget.py @@ -6,8 +6,8 @@ def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None, peak='neg', log=False, ncols=10, - ms_before=1., ms_after=2., max_spikes_per_unit=100, background='on', cmap='viridis', - label_color='r', figure=None, ax=None, axes=None): + background='on', cmap='viridis', label_color='r', figure=None, ax=None, axes=None, + **templates_kwargs): """ Plots sorting comparison confusion matrix. @@ -27,12 +27,6 @@ def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None, If True, log scale is used ncols: int Number of columns if multiple units are displayed - ms_before: float - Time before peak (ms) - ms_after: float - Time after peak (ms) - max_spikes_per_unit: int - Maximum number of spikes to display per unit. background: str 'on' or 'off' cmap: matplotlib colormap @@ -46,6 +40,8 @@ def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None, axes: list of matplotlib axes The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax and figure parameters are ignored + templates_kwargs: keyword arguments for st.postprocessing.get_unit_templates() + Returns ------- @@ -57,26 +53,24 @@ def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None, sorting=sorting, channel_ids=channel_ids, unit_ids=unit_ids, - max_spikes_per_unit=max_spikes_per_unit, peak=peak, log=log, ncols=ncols, - ms_before=ms_before, - ms_after=ms_after, background=background, cmap=cmap, label_color=label_color, figure=figure, ax=ax, - axes=axes + axes=axes, + **templates_kwargs ) W.plot() return W class UnitTemplateMapsWidget(BaseMultiWidget): - def __init__(self, recording, sorting, channel_ids, unit_ids, peak, log, ncols, max_spikes_per_unit, ms_before, - ms_after, background, cmap, label_color='r', figure=None, ax=None, axes=None): + def __init__(self, recording, sorting, channel_ids, unit_ids, peak, log, ncols, background, cmap, label_color='r', + figure=None, ax=None, axes=None, **template_kwargs): BaseMultiWidget.__init__(self, figure, ax, axes) self._recording = recording self._sorting = sorting @@ -85,12 +79,10 @@ def __init__(self, recording, sorting, channel_ids, unit_ids, peak, log, ncols, self._peak = peak self._log = log self._ncols = ncols - self._ms_before = ms_before - self._ms_after = ms_after - self._max_spikes_per_unit = max_spikes_per_unit self._bg = background self._cmap = cmap self._label_color = label_color + self._template_kwargs = template_kwargs self.name = 'UnitTemplateMaps' assert 'location' in self._recording.get_shared_channel_property_names(), "Activity map requires 'location'" \ "property" @@ -102,8 +94,7 @@ def _do_plot(self): locations = self._recording.get_channel_locations(channel_ids=self._channel_ids) templates = st.postprocessing.get_unit_templates(self._recording, self._sorting, channel_ids=self._channel_ids, unit_ids=self._unit_ids, - max_spikes_per_unit=self._max_spikes_per_unit, - ms_before=self._ms_before, ms_after=self._ms_after) + **self._template_kwargs) if self._channel_ids is None: channel_ids = self._recording.get_channel_ids() else: diff --git a/spikewidgets/widgets/unitwaveformswidget/unitwaveformswidget.py b/spikewidgets/widgets/unitwaveformswidget/unitwaveformswidget.py index 6dea1f1..8a9b80b 100644 --- a/spikewidgets/widgets/unitwaveformswidget/unitwaveformswidget.py +++ b/spikewidgets/widgets/unitwaveformswidget/unitwaveformswidget.py @@ -5,10 +5,9 @@ from spikewidgets.widgets.basewidget import BaseMultiWidget -def plot_unit_waveforms(recording, sorting, channel_ids=None, unit_ids=None, ms_before=1., ms_after=2., - max_spikes_per_unit=100, max_channels=16, channel_locs=True, radius=None, - plot_templates=True, show_all_channels=True, color='k', lw=2, axis_equal=False, - plot_channels=False, set_title=True, figure=None, ax=None, axes=None): +def plot_unit_waveforms(recording, sorting, channel_ids=None, unit_ids=None, channel_locs=True, radius=None, + max_channels=None, plot_templates=True, show_all_channels=True, color='k', lw=2, axis_equal=False, + plot_channels=False, set_title=True, figure=None, ax=None, axes=None, **waveforms_kwargs): """ Plots unit waveforms. @@ -22,12 +21,6 @@ def plot_unit_waveforms(recording, sorting, channel_ids=None, unit_ids=None, ms_ The channel ids to display unit_ids: list List of unit ids. - ms_before: float - Time before peak (ms) - ms_after: float - Time after peak (ms) - max_spikes_per_unit: int - Maximum number of spikes to display per unit max_channels: int Maximum number of largest channels to plot waveform channel_locs: bool @@ -57,6 +50,7 @@ def plot_unit_waveforms(recording, sorting, channel_ids=None, unit_ids=None, ms_ axes: list of matplotlib axes The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax and figure parameters are ignored + waveforms_kwargs: keyword arguments for st.postprocessing.get_unit_waveforms() Returns ------- @@ -68,10 +62,7 @@ def plot_unit_waveforms(recording, sorting, channel_ids=None, unit_ids=None, ms_ sorting=sorting, channel_ids=channel_ids, unit_ids=unit_ids, - max_spikes_per_unit=max_spikes_per_unit, max_channels=max_channels, - ms_before=ms_before, - ms_after=ms_after, channel_locs=channel_locs, plot_templates=plot_templates, figure=figure, @@ -83,16 +74,16 @@ def plot_unit_waveforms(recording, sorting, channel_ids=None, unit_ids=None, ms_ lw=lw, axis_equal=axis_equal, plot_channels=plot_channels, - set_title=set_title + set_title=set_title, + **waveforms_kwargs ) W.plot() return W -def plot_unit_templates(recording, sorting, channel_ids=None, unit_ids=None, ms_before=1., ms_after=2., - max_spikes_per_unit=100, max_channels=16, channel_locs=True, radius=None, - show_all_channels=True, color='k', lw=2, axis_equal=False, - plot_channels=False, set_title=True, figure=None, ax=None, axes=None): +def plot_unit_templates(recording, sorting, channel_ids=None, unit_ids=None, max_channels=None, channel_locs=True, + radius=None, show_all_channels=True, color='k', lw=2, axis_equal=False, + plot_channels=False, set_title=True, figure=None, ax=None, axes=None, **template_kwargs): """ Plots unit waveforms. @@ -106,12 +97,6 @@ def plot_unit_templates(recording, sorting, channel_ids=None, unit_ids=None, ms_ The channel ids to display unit_ids: list List of unit ids. - ms_before: float - Time before peak (ms) - ms_after: float - Time after peak (ms) - max_spikes_per_unit: int - Maximum number of spikes to display per unit max_channels: int Maximum number of largest channels to plot waveform channel_locs: bool @@ -139,6 +124,7 @@ def plot_unit_templates(recording, sorting, channel_ids=None, unit_ids=None, ms_ axes: list of matplotlib axes The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax and figure parameters are ignored + template_kwargs: keyword arguments for st.postprocessing.get_unit_templates() Returns @@ -151,10 +137,7 @@ def plot_unit_templates(recording, sorting, channel_ids=None, unit_ids=None, ms_ sorting=sorting, channel_ids=channel_ids, unit_ids=unit_ids, - max_spikes_per_unit=max_spikes_per_unit, max_channels=max_channels, - ms_before=ms_before, - ms_after=ms_after, channel_locs=channel_locs, figure=figure, ax=ax, @@ -165,18 +148,19 @@ def plot_unit_templates(recording, sorting, channel_ids=None, unit_ids=None, ms_ lw=lw, axis_equal=axis_equal, plot_channels=plot_channels, - set_title=set_title + set_title=set_title, + **template_kwargs ) W.plot() return W class UnitWaveformsWidget(BaseMultiWidget): - def __init__(self, *, recording, sorting, channel_ids=None, unit_ids=None, max_spikes_per_unit=50, - max_channels=16, ms_before=1., ms_after=2., channel_locs=True, plot_waveforms=True, + def __init__(self, *, recording, sorting, channel_ids=None, unit_ids=None, max_channels=None, + channel_locs=True, plot_waveforms=True, plot_templates=True, radius=None, show_all_channels=True, figure=None, ax=None, axes=None, color='k', lw=2, axis_equal=False, - plot_channels=False, set_title=True): + plot_channels=False, set_title=True, **kwargs): BaseMultiWidget.__init__(self, figure, ax, axes) self._recording = recording self._sorting = sorting @@ -185,9 +169,6 @@ def __init__(self, *, recording, sorting, channel_ids=None, unit_ids=None, max_s max_channels = recording.get_num_channels() self._max_channels = max_channels self._unit_ids = unit_ids - self._ms_before = ms_before - self._ms_after = ms_after - self._max_spikes_per_unit = max_spikes_per_unit self._ch_locs = channel_locs self.name = 'UnitWaveforms' self._plot_waveforms = plot_waveforms @@ -199,6 +180,8 @@ def __init__(self, *, recording, sorting, channel_ids=None, unit_ids=None, max_s self._axis_equal = axis_equal self._plot_channels = plot_channels self._set_title = set_title + self._kwargs = kwargs + print(self._kwargs) def plot(self): self._do_plot() @@ -224,20 +207,15 @@ def _do_plot(self): if self._plot_waveforms: random_wf = st.postprocessing.get_unit_waveforms(recording=self._recording, sorting=self._sorting, unit_ids=[unit_id], channel_ids=channel_ids, - ms_before=self._ms_before, ms_after=self._ms_after, - max_spikes_per_unit=self._max_spikes_per_unit, - save_as_features=False)[0] + **self._kwargs)[0] else: random_wf = None template = st.postprocessing.get_unit_templates(recording=self._recording, sorting=self._sorting, unit_ids=[unit_id], channel_ids=channel_ids, - ms_before=self._ms_before, - ms_after=self._ms_after, - max_spikes_per_unit=self._max_spikes_per_unit, - save_as_features=False)[0] + **self._kwargs)[0] if self._radius is None: - if self._max_channels < self._recording.get_num_channels(): + if self._max_channels < template.shape[0]: peak_idx = np.unravel_index(np.argmax(np.abs(template)), template.shape)[1] max_channel_idxs = np.argsort(np.abs(template[:, peak_idx]))[::-1][:self._max_channels] @@ -296,17 +274,17 @@ def _plot_spike_shapes_multi(self, list_spikes, max_channels_list, *, ncols=5, * class UnitTemplatesWidget(UnitWaveformsWidget): - def __init__(self, *, recording, sorting, channel_ids=None, unit_ids=None, max_spikes_per_unit=50, - max_channels=16, ms_before=1., ms_after=2., channel_locs=True, figure=None, show_all_channels=True, + def __init__(self, *, recording, sorting, channel_ids=None, unit_ids=None, + max_channels=None, channel_locs=True, figure=None, show_all_channels=True, ax=None, axes=None, radius=None, color='k', lw=2, axis_equal=False, plot_channels=False, - set_title=True): + set_title=True, **template_kwargs): UnitWaveformsWidget.__init__(self, recording=recording, sorting=sorting, channel_ids=channel_ids, - unit_ids=unit_ids, max_spikes_per_unit=max_spikes_per_unit, - max_channels=max_channels, ms_before=ms_before, ms_after=ms_after, + unit_ids=unit_ids, max_channels=max_channels, channel_locs=channel_locs, plot_waveforms=False, plot_templates=True, figure=figure, ax=ax, axes=axes, radius=radius, show_all_channels=show_all_channels, color=color, lw=lw, - axis_equal=axis_equal, plot_channels=plot_channels, set_title=set_title) + axis_equal=axis_equal, plot_channels=plot_channels, set_title=set_title, + **template_kwargs) self.name = 'UnitTemplates'