diff --git a/spikewidgets/tests/test_widgets.py b/spikewidgets/tests/test_widgets.py index fe8be29..7e1c26c 100644 --- a/spikewidgets/tests/test_widgets.py +++ b/spikewidgets/tests/test_widgets.py @@ -3,6 +3,12 @@ import spikecomparison as sc import matplotlib.pyplot as plt import unittest +import sys + +if sys.platform == "win32": + memmaps = [False] +else: + memmaps = [False, True] class TestWidgets(unittest.TestCase): @@ -30,17 +36,20 @@ def test_activitymap(self): sw.plot_activity_map(self._RX, activity='amplitude') def test_unitwaveforms(self): - sw.plot_unit_waveforms(self._RX, self._SX) - fig, axes = plt.subplots(self.num_units, 1) - sw.plot_unit_waveforms(self._RX, self._SX, axes=axes) + for m in memmaps: + sw.plot_unit_waveforms(self._RX, self._SX, memmap=m) + fig, axes = plt.subplots(self.num_units, 1) + sw.plot_unit_waveforms(self._RX, self._SX, axes=axes, memmap=m) def test_unittemplates(self): - sw.plot_unit_templates(self._RX, self._SX) - fig, axes = plt.subplots(self.num_units, 1) - sw.plot_unit_templates(self._RX, self._SX, axes=axes) + for m in memmaps: + sw.plot_unit_templates(self._RX, self._SX, memmap=m) + fig, axes = plt.subplots(self.num_units, 1) + sw.plot_unit_templates(self._RX, self._SX, axes=axes, memmap=m) def test_unittemplatemaps(self): - sw.plot_unit_template_maps(self._RX, self._SX) + for m in memmaps: + sw.plot_unit_template_maps(self._RX, self._SX, memmap=m) def test_ampdist(self): sw.plot_amplitudes_distribution(self._RX, self._SX) @@ -53,9 +62,10 @@ def test_amptime(self): sw.plot_amplitudes_timeseries(self._RX, self._SX, axes=axes) def test_features(self): - sw.plot_pca_features(self._RX, self._SX) - fig, axes = plt.subplots(self.num_units, 1) - sw.plot_pca_features(self._RX, self._SX, axes=axes) + for m in memmaps: + sw.plot_pca_features(self._RX, self._SX, memap=m) + fig, axes = plt.subplots(self.num_units, 1) + sw.plot_pca_features(self._RX, self._SX, axes=axes, memap=m) def test_ach(self): sw.plot_autocorrelograms(self._SX, bin_size=1, window=10) diff --git a/spikewidgets/widgets/featurewidgets/pcawidget.py b/spikewidgets/widgets/featurewidgets/pcawidget.py index f166892..c5bb30a 100644 --- a/spikewidgets/widgets/featurewidgets/pcawidget.py +++ b/spikewidgets/widgets/featurewidgets/pcawidget.py @@ -5,7 +5,7 @@ def plot_pca_features(recording, sorting, unit_ids=None, max_spikes_per_unit=100, nproj=4, colormap=None, - figure=None, ax=None, axes=None): + figure=None, ax=None, axes=None, **pca_kwargs): """ Plots unit PCA features on best projections. @@ -30,6 +30,7 @@ def plot_pca_features(recording, sorting, unit_ids=None, max_spikes_per_unit=100 axes: list of matplotlib axes The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax and figure parameters are ignored + pca_kwargs: keyword arguments for st.postprocessing.compute_unit_pca_scores() Returns @@ -46,35 +47,30 @@ def plot_pca_features(recording, sorting, unit_ids=None, max_spikes_per_unit=100 colormap=colormap, figure=figure, ax=ax, - axes=axes + axes=axes, + **pca_kwargs ) W.plot() return W class PCAWidget(BaseMultiWidget): - def __init__(self, *, recording, sorting, unit_ids=None, max_spikes_per_unit=100, nproj=4, colormap=None, - figure=None, ax=None, axes=None, save_as_features=False, save_waveforms_as_features=False): + def __init__(self, *, recording, sorting, unit_ids=None, nproj=4, colormap=None, + figure=None, ax=None, axes=None, **pca_kwargs): BaseMultiWidget.__init__(self, figure, ax, axes) self._sorting = sorting self._recording = recording self._unit_ids = unit_ids - self._nproj = nproj - self._max_spikes_per_unit = max_spikes_per_unit self._pca_scores = None + self._nproj = nproj self._colormap = colormap - self._save_as_features = save_as_features - self._save_waveforms_as_features = save_waveforms_as_features + self._pca_kwargs = pca_kwargs self.name = 'Feature' def _compute_pca(self): self._pca_scores = st.postprocessing.compute_unit_pca_scores(recording=self._recording, sorting=self._sorting, - by_electrode=True, - max_spikes_per_unit=self._max_spikes_per_unit, - save_as_features=self._save_as_features, - save_waveforms_as_features= - self._save_waveforms_as_features) + **self._pca_kwargs) def plot(self): self._do_plot() @@ -106,8 +102,7 @@ def _do_plot(self): distances.append(dist) proj.append([ch1, pc1, ch2, pc2]) - list_best_proj = np.array( - proj)[np.argsort(distances)[::-1][:self._nproj]] + list_best_proj = np.array(proj)[np.argsort(distances)[::-1][:self._nproj]] self._plot_proj_multi(list_best_proj) def compute_cluster_average_distance(self, pc1, ch1, pc2, ch2): diff --git a/spikewidgets/widgets/mapswidget/activitymapwidget.py b/spikewidgets/widgets/mapswidget/activitymapwidget.py index 71f4da1..2edfca0 100644 --- a/spikewidgets/widgets/mapswidget/activitymapwidget.py +++ b/spikewidgets/widgets/mapswidget/activitymapwidget.py @@ -10,8 +10,8 @@ def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate', log=False, cmap='viridis', background='on', label_color='r', transpose=False, frame=False, colorbar=False, colorbar_bbox=None, - colorbar_orientation='vertical', colorbar_width=0.02, recompute_info=False, - ax=None, figure=None): + colorbar_orientation='vertical', colorbar_width=0.02, + ax=None, figure=None, **activity_kwargs): """ Plots spike rate (estimated using simple threshold detector) as 2D activity map. @@ -43,12 +43,11 @@ def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate', 'vertical' or 'horizontal' colorbar_width: float Width of colorbar in figure coordinates (default 0.02) - recompute_info: bool - If True, spike rates are recomputed figure: matplotlib figure The figure to be used. If not given a figure is created ax: matplotlib axis The axis to be used. If not given an axis is created + activity_kwargs: keyword arguments for st.postprocessing.compute_channel_spiking_activity() Returns ------- @@ -72,7 +71,7 @@ def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate', colorbar_bbox=colorbar_bbox, colorbar_orientation=colorbar_orientation, colorbar_width=colorbar_width, - recompute_info=recompute_info + **activity_kwargs ) W.plot() return W @@ -81,7 +80,7 @@ def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate', class ActivityMapWidget(BaseWidget): def __init__(self, recording, channel_ids, activity, log, trange, cmap, background, label_color='r', transpose=False, frame=False, colorbar=False, colorbar_bbox=None, colorbar_orientation='vertical', - colorbar_width=0.02, recompute_info=False, figure=None, ax=None): + colorbar_width=0.02, figure=None, ax=None, **activity_kwargs): BaseWidget.__init__(self, figure, ax) self._recording = recording self._channel_ids = channel_ids @@ -93,11 +92,11 @@ def __init__(self, recording, channel_ids, activity, log, trange, cmap, backgrou self._frame = frame self._bg = background self._label_color = label_color + self._activity_kwargs = activity_kwargs self._show_colorbar = colorbar self._colorbar_bbox = colorbar_bbox self._colorbar_orient = colorbar_orientation self._colorbar_width = colorbar_width - self._recompute_info = recompute_info self.colorbar = None self.name = 'ActivityMap' assert activity in ['rate', 'amplitude'], "'activity' can be either 'rate' or 'amplitude'" @@ -118,11 +117,7 @@ def _do_plot(self): spike_rates, spike_amplitudes = st.postprocessing.compute_channel_spiking_activity(self._recording, start_frame=self._trange[0], end_frame=self._trange[1], - method='detection', - align=False, - recompute_info= - self._recompute_info, - verbose=False) + **self._activity_kwargs) if self._transpose: locations = np.roll(locations, 1, axis=1) diff --git a/spikewidgets/widgets/mapswidget/templatemapswidget.py b/spikewidgets/widgets/mapswidget/templatemapswidget.py index 91f1430..9b4ad42 100644 --- a/spikewidgets/widgets/mapswidget/templatemapswidget.py +++ b/spikewidgets/widgets/mapswidget/templatemapswidget.py @@ -6,8 +6,8 @@ def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None, peak='neg', log=False, ncols=10, - ms_before=1., ms_after=2., max_spikes_per_unit=100, background='on', cmap='viridis', - label_color='r', figure=None, ax=None, axes=None): + background='on', cmap='viridis', label_color='r', figure=None, ax=None, axes=None, + **templates_kwargs): """ Plots sorting comparison confusion matrix. @@ -27,12 +27,6 @@ def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None, If True, log scale is used ncols: int Number of columns if multiple units are displayed - ms_before: float - Time before peak (ms) - ms_after: float - Time after peak (ms) - max_spikes_per_unit: int - Maximum number of spikes to display per unit. background: str 'on' or 'off' cmap: matplotlib colormap @@ -46,6 +40,8 @@ def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None, axes: list of matplotlib axes The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax and figure parameters are ignored + templates_kwargs: keyword arguments for st.postprocessing.get_unit_templates() + Returns ------- @@ -57,26 +53,24 @@ def plot_unit_template_maps(recording, sorting, channel_ids=None, unit_ids=None, sorting=sorting, channel_ids=channel_ids, unit_ids=unit_ids, - max_spikes_per_unit=max_spikes_per_unit, peak=peak, log=log, ncols=ncols, - ms_before=ms_before, - ms_after=ms_after, background=background, cmap=cmap, label_color=label_color, figure=figure, ax=ax, - axes=axes + axes=axes, + **templates_kwargs ) W.plot() return W class UnitTemplateMapsWidget(BaseMultiWidget): - def __init__(self, recording, sorting, channel_ids, unit_ids, peak, log, ncols, max_spikes_per_unit, ms_before, - ms_after, background, cmap, label_color='r', figure=None, ax=None, axes=None): + def __init__(self, recording, sorting, channel_ids, unit_ids, peak, log, ncols, background, cmap, label_color='r', + figure=None, ax=None, axes=None, **template_kwargs): BaseMultiWidget.__init__(self, figure, ax, axes) self._recording = recording self._sorting = sorting @@ -85,12 +79,10 @@ def __init__(self, recording, sorting, channel_ids, unit_ids, peak, log, ncols, self._peak = peak self._log = log self._ncols = ncols - self._ms_before = ms_before - self._ms_after = ms_after - self._max_spikes_per_unit = max_spikes_per_unit self._bg = background self._cmap = cmap self._label_color = label_color + self._template_kwargs = template_kwargs self.name = 'UnitTemplateMaps' assert 'location' in self._recording.get_shared_channel_property_names(), "Activity map requires 'location'" \ "property" @@ -102,8 +94,7 @@ def _do_plot(self): locations = self._recording.get_channel_locations(channel_ids=self._channel_ids) templates = st.postprocessing.get_unit_templates(self._recording, self._sorting, channel_ids=self._channel_ids, unit_ids=self._unit_ids, - max_spikes_per_unit=self._max_spikes_per_unit, - ms_before=self._ms_before, ms_after=self._ms_after) + **self._template_kwargs) if self._channel_ids is None: channel_ids = self._recording.get_channel_ids() else: diff --git a/spikewidgets/widgets/unitwaveformswidget/unitwaveformswidget.py b/spikewidgets/widgets/unitwaveformswidget/unitwaveformswidget.py index 2c68f06..8a9b80b 100644 --- a/spikewidgets/widgets/unitwaveformswidget/unitwaveformswidget.py +++ b/spikewidgets/widgets/unitwaveformswidget/unitwaveformswidget.py @@ -5,10 +5,9 @@ from spikewidgets.widgets.basewidget import BaseMultiWidget -def plot_unit_waveforms(recording, sorting, channel_ids=None, unit_ids=None, ms_before=1., ms_after=2., - max_spikes_per_unit=100, max_channels=16, channel_locs=True, radius=None, - plot_templates=True, show_all_channels=True, color='k', lw=2, axis_equal=False, - plot_channels=False, set_title=True, figure=None, ax=None, axes=None): +def plot_unit_waveforms(recording, sorting, channel_ids=None, unit_ids=None, channel_locs=True, radius=None, + max_channels=None, plot_templates=True, show_all_channels=True, color='k', lw=2, axis_equal=False, + plot_channels=False, set_title=True, figure=None, ax=None, axes=None, **waveforms_kwargs): """ Plots unit waveforms. @@ -22,12 +21,6 @@ def plot_unit_waveforms(recording, sorting, channel_ids=None, unit_ids=None, ms_ The channel ids to display unit_ids: list List of unit ids. - ms_before: float - Time before peak (ms) - ms_after: float - Time after peak (ms) - max_spikes_per_unit: int - Maximum number of spikes to display per unit max_channels: int Maximum number of largest channels to plot waveform channel_locs: bool @@ -57,6 +50,7 @@ def plot_unit_waveforms(recording, sorting, channel_ids=None, unit_ids=None, ms_ axes: list of matplotlib axes The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax and figure parameters are ignored + waveforms_kwargs: keyword arguments for st.postprocessing.get_unit_waveforms() Returns ------- @@ -68,10 +62,7 @@ def plot_unit_waveforms(recording, sorting, channel_ids=None, unit_ids=None, ms_ sorting=sorting, channel_ids=channel_ids, unit_ids=unit_ids, - max_spikes_per_unit=max_spikes_per_unit, max_channels=max_channels, - ms_before=ms_before, - ms_after=ms_after, channel_locs=channel_locs, plot_templates=plot_templates, figure=figure, @@ -83,16 +74,16 @@ def plot_unit_waveforms(recording, sorting, channel_ids=None, unit_ids=None, ms_ lw=lw, axis_equal=axis_equal, plot_channels=plot_channels, - set_title=set_title + set_title=set_title, + **waveforms_kwargs ) W.plot() return W -def plot_unit_templates(recording, sorting, channel_ids=None, unit_ids=None, ms_before=1., ms_after=2., - max_spikes_per_unit=100, max_channels=16, channel_locs=True, radius=None, - show_all_channels=True, color='k', lw=2, axis_equal=False, - plot_channels=False, set_title=True, figure=None, ax=None, axes=None): +def plot_unit_templates(recording, sorting, channel_ids=None, unit_ids=None, max_channels=None, channel_locs=True, + radius=None, show_all_channels=True, color='k', lw=2, axis_equal=False, + plot_channels=False, set_title=True, figure=None, ax=None, axes=None, **template_kwargs): """ Plots unit waveforms. @@ -106,12 +97,6 @@ def plot_unit_templates(recording, sorting, channel_ids=None, unit_ids=None, ms_ The channel ids to display unit_ids: list List of unit ids. - ms_before: float - Time before peak (ms) - ms_after: float - Time after peak (ms) - max_spikes_per_unit: int - Maximum number of spikes to display per unit max_channels: int Maximum number of largest channels to plot waveform channel_locs: bool @@ -139,6 +124,7 @@ def plot_unit_templates(recording, sorting, channel_ids=None, unit_ids=None, ms_ axes: list of matplotlib axes The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax and figure parameters are ignored + template_kwargs: keyword arguments for st.postprocessing.get_unit_templates() Returns @@ -151,10 +137,7 @@ def plot_unit_templates(recording, sorting, channel_ids=None, unit_ids=None, ms_ sorting=sorting, channel_ids=channel_ids, unit_ids=unit_ids, - max_spikes_per_unit=max_spikes_per_unit, max_channels=max_channels, - ms_before=ms_before, - ms_after=ms_after, channel_locs=channel_locs, figure=figure, ax=ax, @@ -165,37 +148,40 @@ def plot_unit_templates(recording, sorting, channel_ids=None, unit_ids=None, ms_ lw=lw, axis_equal=axis_equal, plot_channels=plot_channels, - set_title=set_title + set_title=set_title, + **template_kwargs ) W.plot() return W class UnitWaveformsWidget(BaseMultiWidget): - def __init__(self, *, recording, sorting, channel_ids=None, unit_ids=None, max_spikes_per_unit=50, - max_channels=16, ms_before=1., ms_after=2., channel_locs=True, plot_templates=True, radius=None, + def __init__(self, *, recording, sorting, channel_ids=None, unit_ids=None, max_channels=None, + channel_locs=True, plot_waveforms=True, + plot_templates=True, radius=None, show_all_channels=True, figure=None, ax=None, axes=None, color='k', lw=2, axis_equal=False, - plot_channels=False, set_title=True): + plot_channels=False, set_title=True, **kwargs): BaseMultiWidget.__init__(self, figure, ax, axes) self._recording = recording self._sorting = sorting self._channel_ids = channel_ids + if max_channels is None: + max_channels = recording.get_num_channels() self._max_channels = max_channels self._unit_ids = unit_ids - self._ms_before = ms_before - self._ms_after = ms_after - self._max_spikes_per_unit = max_spikes_per_unit self._ch_locs = channel_locs self.name = 'UnitWaveforms' - self._plot_waveforms = True + self._plot_waveforms = plot_waveforms self._plot_templates = plot_templates self._radius = radius self._show_all_channels = show_all_channels - self._color=color - self._lw=lw - self._axis_equal=axis_equal - self._plot_channels=plot_channels + self._color = color + self._lw = lw + self._axis_equal = axis_equal + self._plot_channels = plot_channels self._set_title = set_title + self._kwargs = kwargs + print(self._kwargs) def plot(self): self._do_plot() @@ -209,7 +195,7 @@ def _do_plot(self): channel_ids = self._recording.get_channel_ids() if 'location' in self._recording.get_shared_channel_property_names(): all_locations = np.array(self._recording.get_channel_locations()) - channel_locations = np.array(self._recording.get_channel_locations(channel_ids=channel_ids)) + channel_locations = np.array(self._recording.get_channel_locations(channel_ids=channel_ids)) else: all_locations = None channel_locations = None @@ -218,14 +204,18 @@ def _do_plot(self): for unit_id in unit_ids: spiketrain = self._sorting.get_unit_spike_train(unit_id=unit_id) if spiketrain is not None: - random_wf = st.postprocessing.get_unit_waveforms(recording=self._recording, sorting=self._sorting, - unit_ids=[unit_id], channel_ids=channel_ids, - ms_before=self._ms_before, ms_after=self._ms_after, - max_spikes_per_unit=self._max_spikes_per_unit, - save_as_features=False)[0] + if self._plot_waveforms: + random_wf = st.postprocessing.get_unit_waveforms(recording=self._recording, sorting=self._sorting, + unit_ids=[unit_id], channel_ids=channel_ids, + **self._kwargs)[0] + else: + random_wf = None + template = st.postprocessing.get_unit_templates(recording=self._recording, + sorting=self._sorting, + unit_ids=[unit_id], channel_ids=channel_ids, + **self._kwargs)[0] if self._radius is None: - if self._max_channels < self._recording.get_num_channels(): - template = np.mean(random_wf, axis=0) + if self._max_channels < template.shape[0]: peak_idx = np.unravel_index(np.argmax(np.abs(template)), template.shape)[1] max_channel_idxs = np.argsort(np.abs(template[:, peak_idx]))[::-1][:self._max_channels] @@ -236,22 +226,23 @@ def _do_plot(self): else: max_channels_list.append(np.arange(len(channel_ids))) else: - template = np.mean(random_wf, axis=0) peak_idx = np.unravel_index(np.argmax(np.abs(template)), template.shape)[1] max_channel_idx = np.argsort(np.abs(template[:, peak_idx]))[::-1][0] - c = all_locations[max_channel_idx] - d = np.sqrt(np.sum((channel_locations-c)**2, axis=1)) - max_channels_list.append(np.array(channel_ids)[np.where(d<=self._radius)[0]]) + center_location = all_locations[max_channel_idx] + dists = np.array([np.linalg.norm(loc - center_location) for loc in all_locations]) + max_channels_list.append(np.where(dists <= self._radius)[0]) spikes = random_wf if self._set_title: item = dict( representative_waveforms=spikes, + average_waveform=template, title='Unit {}'.format(int(unit_id)) ) else: item = dict( representative_waveforms=spikes, + average_waveform=template, title='' ) list_spikes.append(item) @@ -262,12 +253,14 @@ def _do_plot(self): if self._ch_locs: self._plot_spike_shapes_multi(list_spikes, channel_locations=channel_locations, all_locations=all_locations, max_channels_list=max_channels_list, - plot_templates=self._plot_templates, plot_waveforms=self._plot_waveforms, - show_all_channels=self._show_all_channels, plot_channels=self._plot_channels) + plot_templates=self._plot_templates, plot_waveforms=self._plot_waveforms, + show_all_channels=self._show_all_channels, + plot_channels=self._plot_channels) else: self._plot_spike_shapes_multi(list_spikes, channel_locations=None, max_channels_list=max_channels_list, - plot_templates=self._plot_templates, plot_waveforms=self._plot_waveforms, - show_all_channels=self._show_all_channels, plot_channels=self._plot_channels) + plot_templates=self._plot_templates, plot_waveforms=self._plot_waveforms, + show_all_channels=self._show_all_channels, + plot_channels=self._plot_channels) def _plot_spike_shapes_multi(self, list_spikes, max_channels_list, *, ncols=5, **kwargs): vscale, ylim = _determine_global_vscale_ylim(list_spikes) @@ -277,32 +270,33 @@ def _plot_spike_shapes_multi(self, list_spikes, max_channels_list, *, ncols=5, * for i, item in enumerate(list_spikes): ax = self.get_tiled_ax(i, nrows, ncols) _plot_spike_shapes(**item, **kwargs, ax=ax, channels=max_channels_list[i], vscale=vscale, ylim_wf=ylim, - color=self._color, lw=self._lw) + color=self._color, lw=self._lw) class UnitTemplatesWidget(UnitWaveformsWidget): - def __init__(self, *, recording, sorting, channel_ids=None, unit_ids=None, max_spikes_per_unit=50, - max_channels=16, ms_before=1., ms_after=2., channel_locs=True, figure=None, show_all_channels=True, + def __init__(self, *, recording, sorting, channel_ids=None, unit_ids=None, + max_channels=None, channel_locs=True, figure=None, show_all_channels=True, ax=None, axes=None, radius=None, color='k', lw=2, axis_equal=False, plot_channels=False, - set_title=True): + set_title=True, **template_kwargs): UnitWaveformsWidget.__init__(self, recording=recording, sorting=sorting, channel_ids=channel_ids, - unit_ids=unit_ids, max_spikes_per_unit=max_spikes_per_unit, - max_channels=max_channels, ms_before=ms_before, ms_after=ms_after, - channel_locs=channel_locs, plot_templates=True, figure=figure, ax=ax, axes=axes, + unit_ids=unit_ids, max_channels=max_channels, + channel_locs=channel_locs, plot_waveforms=False, + plot_templates=True, figure=figure, ax=ax, axes=axes, radius=radius, show_all_channels=show_all_channels, color=color, lw=lw, - axis_equal=axis_equal, plot_channels=plot_channels, set_title=set_title) + axis_equal=axis_equal, plot_channels=plot_channels, set_title=set_title, + **template_kwargs) self.name = 'UnitTemplates' - self._plot_waveforms = False -def _plot_spike_shapes(*, ax, channels, representative_waveforms=None, channel_locations=None, - all_locations=None, color='blue', vscale=None, plot_waveforms=True, plot_templates=True, - ylim_wf=None, title='', show_all_channels=True, lw=2, axis_equal=False, plot_channels=False): - if representative_waveforms is None: +def _plot_spike_shapes(*, ax, channels, representative_waveforms=None, average_waveform=None, + channel_locations=None, all_locations=None, color='blue', vscale=None, plot_waveforms=True, + plot_templates=True, ylim_wf=None, title='', show_all_channels=True, lw=2, axis_equal=False, + plot_channels=False): + if representative_waveforms is None and average_waveform is None: raise Exception('You must provide either average_waveform, representative waveforms, or both') - + waveforms = representative_waveforms - average_waveform = np.mean(waveforms, axis=0) + average_waveform = average_waveform M = average_waveform.shape[0] # number of channels if channel_locations is None: channel_locations = np.zeros((M, 2)) diff --git a/spikewidgets/widgets/utils.py b/spikewidgets/widgets/utils.py index 35e8756..773d926 100644 --- a/spikewidgets/widgets/utils.py +++ b/spikewidgets/widgets/utils.py @@ -31,7 +31,6 @@ def on_press(self, event): self.press = x0, y0, event.xdata, event.ydata LabeledRectangle.lock = self self.text.set_visible(True) - self.text.draw() def on_release(self, event): 'on release we reset the press data' @@ -40,7 +39,6 @@ def on_release(self, event): self.press = None LabeledRectangle.lock = None self.text.set_visible(False) - self.text.draw() def disconnect(self): 'disconnect all the stored connection ids' @@ -78,7 +76,6 @@ def on_press(self, event): self.press = x0, y0, event.xdata, event.ydata LabeledEllipse.lock = self self.text.set_visible(True) - self.text.draw() def on_release(self, event): 'on release we reset the press data' @@ -87,7 +84,6 @@ def on_release(self, event): self.press = None LabeledEllipse.lock = None self.text.set_visible(False) - self.text.draw() def disconnect(self): 'disconnect all the stored connection ids'