Skip to content
This repository has been archived by the owner on Jun 6, 2023. It is now read-only.

Commit

Permalink
Merge pull request #62 from SpikeInterface/mhh
Browse files Browse the repository at this point in the history
Various widget improvements
  • Loading branch information
alejoe91 authored Jun 20, 2020
2 parents c73d460 + 6fc49ed commit 0a44d74
Show file tree
Hide file tree
Showing 11 changed files with 335 additions and 127 deletions.
20 changes: 20 additions & 0 deletions spikewidgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import spikeextractors as se
import spikewidgets as sw
import spikecomparison as sc
import matplotlib.pyplot as plt
import unittest


class TestWidgets(unittest.TestCase):
def setUp(self):
self._RX, self._SX = se.example_datasets.toy_example(num_channels=4, duration=10)
self.num_units = len(self._SX.get_unit_ids())

def tearDown(self):
pass
Expand All @@ -25,30 +27,46 @@ def test_geometry(self):

def test_unitwaveforms(self):
sw.plot_unit_waveforms(self._RX, self._SX)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_unit_waveforms(self._RX, self._SX, axes=axes)

def test_unittemplates(self):
sw.plot_unit_templates(self._RX, self._SX)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_unit_templates(self._RX, self._SX, axes=axes)

def test_unittemplatemaps(self):
sw.plot_unit_template_maps(self._RX, self._SX)

def test_ampdist(self):
sw.plot_amplitudes_distribution(self._RX, self._SX)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_amplitudes_distribution(self._RX, self._SX, axes=axes)

def test_amptime(self):
sw.plot_amplitudes_timeseries(self._RX, self._SX)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_amplitudes_timeseries(self._RX, self._SX, axes=axes)

def test_features(self):
sw.plot_pca_features(self._RX, self._SX)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_pca_features(self._RX, self._SX, axes=axes)

def test_ach(self):
sw.plot_autocorrelograms(self._SX, bin_size=1, window=10)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_autocorrelograms(self._SX, axes=axes)

def test_cch(self):
sw.plot_crosscorrelograms(self._SX, bin_size=1, window=10)
fig, axes = plt.subplots(self.num_units, self.num_units) # for cch need square matrix
sw.plot_crosscorrelograms(self._SX, axes=axes)

def test_isi(self):
sw.plot_isi_distribution(self._SX, bins=10, window=1)
fig, axes = plt.subplots(self.num_units, 1)
sw.plot_isi_distribution(self._SX, axes=axes)

def test_rasters(self):
sw.plot_rasters(self._SX)
Expand All @@ -70,6 +88,8 @@ def test_multicomp_graph(self):
sw.plot_multicomp_graph(msc, edge_cmap='viridis', node_cmap='rainbow', draw_labels=False)
sw.plot_multicomp_agreement(msc)
sw.plot_multicomp_agreement_by_sorter(msc)
fig, axes = plt.subplots(len(msc.sorting_list), 1)
sw.plot_multicomp_agreement_by_sorter(msc, axes=axes)


if __name__ == '__main__':
Expand Down
30 changes: 19 additions & 11 deletions spikewidgets/widgets/amplitudewidget/amplitudewidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def plot_amplitudes_distribution(recording, sorting, unit_ids=None, max_spikes_per_unit=100,
figure=None, ax=None):
figure=None, ax=None, axes=None):
"""
Plots waveform amplitudes distribution.
Expand All @@ -17,11 +17,14 @@ def plot_amplitudes_distribution(recording, sorting, unit_ids=None, max_spikes_p
unit_ids: list
List of unit ids
max_spikes_per_unit: int
Maximum number of spikes to display per unit.
Maximum number of spikes to display per unit
figure: matplotlib figure
The figure to be used. If not given a figure is created
ax: matplotlib axis
The axis to be used. If not given an axis is created
axes: list of matplotlib axes
The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax
and figure parameters are ignored
Returns
-------
Expand All @@ -34,14 +37,15 @@ def plot_amplitudes_distribution(recording, sorting, unit_ids=None, max_spikes_p
unit_ids=unit_ids,
max_spikes_per_unit=max_spikes_per_unit,
figure=figure,
ax=ax
ax=ax,
axes=axes
)
W.plot()
return W


def plot_amplitudes_timeseries(recording, sorting, unit_ids=None, max_spikes_per_unit=100,
figure=None, ax=None):
figure=None, ax=None, axes=None):
"""
Plots waveform amplitudes timeseries.
Expand All @@ -59,6 +63,9 @@ def plot_amplitudes_timeseries(recording, sorting, unit_ids=None, max_spikes_per
The figure to be used. If not given a figure is created
ax: matplotlib axis
The axis to be used. If not given an axis is created
axes: list of matplotlib axes
The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax
and figure parameters are ignored
Returns
-------
Expand All @@ -71,15 +78,16 @@ def plot_amplitudes_timeseries(recording, sorting, unit_ids=None, max_spikes_per
unit_ids=unit_ids,
max_spikes_per_unit=max_spikes_per_unit,
figure=figure,
ax=ax
ax=ax,
axes=axes
)
W.plot()
return W


class AmplitudeBaseWidget(BaseMultiWidget):
def __init__(self, recording, sorting, max_spikes_per_unit=100, figure=None, ax=None):
BaseMultiWidget.__init__(self, figure, ax)
def __init__(self, recording, sorting, max_spikes_per_unit=100, figure=None, ax=None, axes=None):
BaseMultiWidget.__init__(self, figure, ax, axes)
self._sorting = sorting
self._recording = recording
self._max_spikes_per_unit = max_spikes_per_unit
Expand All @@ -100,8 +108,8 @@ def compute_amps(self, *, times):


class AmplitudeTimeseriesWidget(AmplitudeBaseWidget):
def __init__(self, *, recording, sorting, unit_ids=None, max_spikes_per_unit=100, figure=None, ax=None):
AmplitudeBaseWidget.__init__(self, recording, sorting, max_spikes_per_unit, figure, ax)
def __init__(self, *, recording, sorting, unit_ids=None, max_spikes_per_unit=100, figure=None, ax=None, axes=None):
AmplitudeBaseWidget.__init__(self, recording, sorting, max_spikes_per_unit, figure, ax, axes)
self._unit_ids = unit_ids
self.name = 'AmplitudeTimeseries'

Expand Down Expand Up @@ -142,8 +150,8 @@ def _plot_amp_time_multi(self, list_isi, *, ncols=5, ylim=None, **kwargs):


class AmplitudeDistributionWidget(AmplitudeBaseWidget):
def __init__(self, *, recording, sorting, unit_ids=None, max_spikes_per_unit=100, figure=None, ax=None):
AmplitudeBaseWidget.__init__(self, recording, sorting, max_spikes_per_unit, figure, ax)
def __init__(self, *, recording, sorting, unit_ids=None, max_spikes_per_unit=100, figure=None, ax=None, axes=None):
AmplitudeBaseWidget.__init__(self, recording, sorting, max_spikes_per_unit, figure, ax, axes)
self._unit_ids = unit_ids
self.name = 'AmplitudeDistribution'

Expand Down
77 changes: 55 additions & 22 deletions spikewidgets/widgets/basewidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,35 +27,68 @@ def get_name(self):


class BaseMultiWidget:
def __init__(self, figure=None, ax=None):
if figure is None and ax is None:
def __init__(self, figure=None, ax=None, axes=None):
self._use_gs = True
self._gs = None
self.axes = []
if figure is None and ax is None and axes is None:
self.figure = plt.figure()
self.ax = self.figure.add_subplot(111)
elif ax is None:
elif ax is None and axes is None:
self.figure = figure
self.ax = self.figure.add_subplot(111)
else:
elif axes is None:
self.figure = ax.get_figure()
self.ax = ax
self.axes = []
self._gs = None
self.ax.axis('off')
if axes is not None:
assert len(axes) > 1, "'axes' should be a list with more than one axis"
self.axes = axes
self.axes = np.array(self.axes)
assert self.axes.ndim == 2 or self.axes.ndim == 1, "'axes' can be a 1-d array or list or a 2d array of axis"
if self.axes.ndim == 1:
self.figure = self.axes[0].get_figure()
else:
self.figure = self.axes[0, 0].get_figure()
self._use_gs = False
else:
self.ax.axis('off')
self.name = None

def get_tiled_ax(self, i, nrows, ncols, hspace=0.3, wspace=0.3, is_diag=False):
if self._gs is None:
self._gs = gridspec.GridSpecFromSubplotSpec(int(nrows), int(ncols), subplot_spec=self.ax,
hspace=hspace, wspace=wspace)
r = int(i // ncols)
c = int(np.mod(i, ncols))
gs_sel = self._gs[r, c]
ax = self.figure.add_subplot(gs_sel)
self.axes.append(ax)
if r == c:
diag = True
else:
diag = False
if is_diag:
return ax, diag
if self._use_gs:
if self._gs is None:
self._gs = gridspec.GridSpecFromSubplotSpec(int(nrows), int(ncols), subplot_spec=self.ax,
hspace=hspace, wspace=wspace)
r = int(i // ncols)
c = int(np.mod(i, ncols))
gs_sel = self._gs[r, c]
ax = self.figure.add_subplot(gs_sel)
self.axes.append(ax)
if r == c:
diag = True
else:
diag = False
if is_diag:
return ax, diag
else:
return ax
else:
return ax
if np.array(self.axes).ndim == 1:
assert i < len(self.axes), f"{i} exceeds the number of available axis"
if is_diag:
return self.axes[i], False
else:
return self.axes[i]
else:
nrows = self.axes.shape[0]
ncols = self.axes.shape[1]
r = int(i // ncols)
c = int(np.mod(i, ncols))
if r == c:
diag = True
else:
diag = False
if is_diag:
return self.axes[r, c], diag
else:
return self.axes[r, c]
29 changes: 20 additions & 9 deletions spikewidgets/widgets/correlogramswidget/correlogramswidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def plot_autocorrelograms(sorting, sampling_frequency=None, unit_ids=None, bin_size=2, window=50,
figure=None, ax=None):
figure=None, ax=None, axes=None):
"""
Plots spike train auto-correlograms.
Expand All @@ -25,6 +25,9 @@ def plot_autocorrelograms(sorting, sampling_frequency=None, unit_ids=None, bin_s
The figure to be used. If not given a figure is created
ax: matplotlib axis
The axis to be used. If not given an axis is created
axes: list of matplotlib axes
The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax
and figure parameters are ignored
Returns
-------
Expand All @@ -44,14 +47,15 @@ def plot_autocorrelograms(sorting, sampling_frequency=None, unit_ids=None, bin_s
binsize=bin_size,
window=window,
figure=figure,
ax=ax
ax=ax,
axes=axes
)
W.plot()
return W


def plot_crosscorrelograms(sorting, sampling_frequency=None, unit_ids=None, bin_size=1, window=10,
figure=None, ax=None):
figure=None, ax=None, axes=None):
"""
Plots spike train cross-correlograms.
Expand All @@ -71,6 +75,9 @@ def plot_crosscorrelograms(sorting, sampling_frequency=None, unit_ids=None, bin_
The figure to be used. If not given a figure is created
ax: matplotlib axis
The axis to be used. If not given an axis is created
axes: list of matplotlib axes
The axes to be used for the individual plots. If not given the required axes are created. If provided, the ax
and figure parameters are ignored
Returns
-------
Expand All @@ -90,15 +97,17 @@ def plot_crosscorrelograms(sorting, sampling_frequency=None, unit_ids=None, bin_
binsize=bin_size,
window=window,
figure=figure,
ax=ax
ax=ax,
axes=axes
)
W.plot()
return W


class AutoCorrelogramsWidget(BaseMultiWidget):
def __init__(self, *, sorting, sampling_frequency, unit_ids=None, binsize=2, window=50, figure=None, ax=None):
BaseMultiWidget.__init__(self, figure, ax)
def __init__(self, *, sorting, sampling_frequency, unit_ids=None, binsize=2, window=50, figure=None, ax=None,
axes=None):
BaseMultiWidget.__init__(self, figure, ax=ax, axes=axes)
self._sorting = sorting
self._unit_ids = unit_ids
self._sampling_frequency = sampling_frequency
Expand Down Expand Up @@ -144,8 +153,9 @@ def _plot_autocorrelograms_multi(self, list_corr, *, ncols=5, **kwargs):


class CrossCorrelogramsWidget(BaseMultiWidget):
def __init__(self, *, sorting, sampling_frequency, unit_ids=None, binsize=2, window=50, figure=None, ax=None):
BaseMultiWidget.__init__(self, figure, ax)
def __init__(self, *, sorting, sampling_frequency, unit_ids=None, binsize=2, window=50, figure=None, ax=None,
axes=None):
BaseMultiWidget.__init__(self, figure, ax, axes)
self._sorting = sorting
self._unit_ids = unit_ids
self._sampling_frequency = sampling_frequency
Expand Down Expand Up @@ -191,7 +201,8 @@ def _plot_crosscorrelograms_multi(self, list_corr, **kwargs):
units = self._sorting.get_unit_ids()
ncols = len(units)
nrows = np.ceil(len(list_corr) / ncols)
self.figure.set_size_inches((3*ncols, 2*nrows))
if self._use_gs:
self.figure.set_size_inches((3*ncols, 2*nrows))
for i, item in enumerate(list_corr):
ax, diag = self.get_tiled_ax(i, nrows, ncols, hspace=1.5, wspace=0.2, is_diag=True)
if diag:
Expand Down
Loading

0 comments on commit 0a44d74

Please sign in to comment.