Skip to content

Commit

Permalink
Merge pull request #2646 from alejoe91/fix-scaling-spikes-on-traces
Browse files Browse the repository at this point in the history
Improve spikes on traces
  • Loading branch information
samuelgarcia authored Apr 12, 2024
2 parents 765cb67 + 371bf06 commit 482bd74
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 24 deletions.
31 changes: 17 additions & 14 deletions src/spikeinterface/widgets/spikes_on_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class SpikesOnTracesWidget(BaseWidget):
clim : None, tuple or dict, default: None
When mode is "map", this argument controls color limits.
If dict, keys should be the same as recording keys
scale : float, default: 1
Scale factor for the traces
with_colorbar : bool, default: True
When mode is "map", a colorbar is added
tile_size : int, default: 512
Expand Down Expand Up @@ -79,6 +81,9 @@ def __init__(
clim=None,
tile_size=512,
seconds_per_row=0.2,
scale=1,
spike_width_ms=4,
spike_height_um=20,
with_colorbar=True,
backend=None,
**backend_kwargs,
Expand All @@ -87,6 +92,7 @@ def __init__(
self.check_extensions(sorting_analyzer, "unit_locations")

sorting: BaseSorting = sorting_analyzer.sorting
recording: BaseRecording = sorting_analyzer.recording

if unit_ids is None:
unit_ids = sorting.get_unit_ids()
Expand All @@ -112,7 +118,6 @@ def __init__(
assert isinstance(sparsity, ChannelSparsity)

unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit")

options = dict(
segment_index=segment_index,
channel_ids=channel_ids,
Expand All @@ -127,6 +132,7 @@ def __init__(
clim=clim,
tile_size=tile_size,
with_colorbar=with_colorbar,
scale=scale,
)

plot_data = dict(
Expand All @@ -136,6 +142,8 @@ def __init__(
sparsity=sparsity,
unit_colors=unit_colors,
unit_locations=unit_locations,
spike_width_ms=spike_width_ms,
spike_height_um=spike_height_um,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
Expand All @@ -162,10 +170,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

frame_range = traces_widget.data_plot["frame_range"]
segment_index = traces_widget.data_plot["segment_index"]
min_y = np.min(traces_widget.data_plot["channel_locations"][:, 1])
max_y = np.max(traces_widget.data_plot["channel_locations"][:, 1])

n = len(traces_widget.data_plot["channel_ids"])

if ax.get_legend() is not None:
ax.get_legend().remove()
Expand All @@ -186,9 +190,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
spike_times_to_plot = sorting.get_unit_spike_train(
unit, segment_index=segment_index, return_times=True
)[spike_start:spike_end]
width = dp.spike_width_ms / 1000
height = dp.spike_height_um
unit_y_loc = dp.unit_locations[unit][1]
width = 2 * 1e-3
ellipse_kwargs = dict(width=width, height=10, fc="none", ec=dp.unit_colors[unit], lw=2)
ellipse_kwargs = dict(width=width, height=height, fc="none", ec=dp.unit_colors[unit], lw=2)
patches = [Ellipse((s, unit_y_loc), **ellipse_kwargs) for s in spike_times_to_plot]
for p in patches:
ax.add_patch(p)
Expand All @@ -210,11 +215,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
label_set = False
if len(spike_frames_to_plot) > 0:
vspacing = traces_widget.data_plot["vspacing"]
traces = traces_widget.data_plot["list_traces"][0]
traces = traces_widget.data_plot["list_traces"][0] * dp.options["scale"]

# TODO find a better way
nbefore = 30
nafter = 60
nbefore = nafter = int(dp.spike_width_ms / 2 * sorting_analyzer.sampling_frequency / 1000)
waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-nbefore, nafter) - frame_range[0]
waveform_idxs = np.clip(waveform_idxs, 0, len(traces_widget.data_plot["times"]) - 1)

Expand All @@ -223,7 +226,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
# discontinuity
times[:, -1] = np.nan
times_r = times.reshape(times.shape[0] * times.shape[1])
waveforms = traces[waveform_idxs]
waveforms = traces[waveform_idxs] * dp.options["scale"]
waveforms_r = waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2]))

for i, chan_id in enumerate(traces_widget.data_plot["channel_ids"]):
Expand All @@ -234,7 +237,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
handles.append(l[0])
labels.append(unit)
label_set = True
# ax.legend(handles, labels)

def plot_ipywidgets(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -295,6 +297,7 @@ def _update_ipywidget(self, change=None):

unit_ids = self.unit_selector.value
start_frame, end_frame, segment_index = self._traces_widget.time_slider.value
scale = self._traces_widget.scaler.value
channel_ids = self._traces_widget.channel_selector.value
mode = self._traces_widget.mode_selector.value

Expand All @@ -304,7 +307,7 @@ def _update_ipywidget(self, change=None):
dict(
channel_ids=channel_ids,
segment_index=segment_index,
# frame_range=(start_frame, end_frame),
scale=scale,
time_range=np.array([start_frame, end_frame]) / self.sampling_frequency,
mode=mode,
with_colorbar=False,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ def test_plot_multicomparison(self):
# mytest.test_plot_traces()
# mytest.test_plot_spikes_on_traces()
# mytest.test_plot_unit_waveforms()
mytest.test_plot_unit_templates()
mytest.test_plot_spikes_on_traces()
# mytest.test_plot_unit_depths()
# mytest.test_plot_autocorrelograms()
# mytest.test_plot_crosscorrelograms()
Expand Down
25 changes: 17 additions & 8 deletions src/spikeinterface/widgets/traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class TracesWidget(BaseWidget):
clim: None, tuple or dict, default: None
When mode is "map", this argument controls color limits.
If dict, keys should be the same as recording keys
scale: float, default: 1
Scale factor for the traces
with_colorbar: bool, default: True
When mode is "map", a colorbar is added
tile_size: int, default: 1500
Expand All @@ -70,6 +72,7 @@ def __init__(
clim=None,
tile_size=1500,
seconds_per_row=0.2,
scale=1,
with_colorbar=True,
add_legend=True,
backend=None,
Expand All @@ -91,12 +94,7 @@ def __init__(
f"is currently of type {type(recording)}"
)

if rec0.has_channel_location():
channel_locations = rec0.get_channel_locations()
else:
channel_locations = None

if order_channel_by_depth and channel_locations is not None:
if order_channel_by_depth and rec0.has_channel_location():
from ..preprocessing import depth_order

rec0 = depth_order(rec0)
Expand All @@ -111,6 +109,11 @@ def __init__(
if channel_ids is None:
channel_ids = rec0.channel_ids

if rec0.has_channel_location():
channel_locations = rec0.get_channel_locations()
else:
channel_locations = None

layer_keys = list(recordings.keys())

if segment_index is None:
Expand Down Expand Up @@ -142,6 +145,8 @@ def __init__(
recordings, channel_ids, time_range, segment_index, return_scaled=return_scaled
)

list_traces = [traces * scale for traces in list_traces]

# stat for auto scaling done on the first layer
traces0 = list_traces[0]
mean_channel_std = np.mean(np.std(traces0, axis=0))
Expand Down Expand Up @@ -237,8 +242,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

ax = self.ax
n = len(dp.channel_ids)
rec0 = dp.recordings[list(dp.recordings.keys())[0]]
channel_indices = rec0.ids_to_indices(dp.channel_ids)

if dp.channel_locations is not None:
y_locs = dp.channel_locations[:, 1]
y_locs = dp.channel_locations[channel_indices, 1]
else:
y_locs = np.arange(n)
min_y = np.min(y_locs)
Expand Down Expand Up @@ -287,7 +295,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
channel_labels = np.array([str(chan_id) for chan_id in dp.channel_ids])
ax.set_yticklabels(channel_labels)
else:
ax.get_yaxis().set_visible(False)
ax.set_yticks([min_y, max_y])
ax.set_yticklabels([min_y, max_y])

def plot_ipywidgets(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/utils_ipywidgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def __init__(self, value=1.0, factor=1.2, **kwargs):
assert factor > 1.0
self.factor = factor

self.scale_label = W.Label("Scale", layout=W.Layout(layout=W.Layout(width="95%"), justify_content="center"))
self.scale_label = W.Label("Scale", layout=W.Layout(width="95%", justify_content="center"))

self.plus_selector = W.Button(
description="",
Expand Down

0 comments on commit 482bd74

Please sign in to comment.