Commit

Merge pull request #2519 from samuelgarcia/analyzer
black
samuelgarcia authored Feb 27, 2024
2 parents 486ef44 + ae274db commit 5cd00e4
Showing 41 changed files with 780 additions and 705 deletions.
93 changes: 57 additions & 36 deletions examples/how_to/analyse_neuropixels.py
@@ -28,9 +28,9 @@
from pathlib import Path

# +
base_folder = Path('/mnt/data/sam/DataSpikeSorting/neuropixel_example/')
base_folder = Path("/mnt/data/sam/DataSpikeSorting/neuropixel_example/")

spikeglx_folder = base_folder / 'Rec_1_10_11_2021_g0'
spikeglx_folder = base_folder / "Rec_1_10_11_2021_g0"

# -

@@ -40,11 +40,11 @@
# We need to specify which one to read:
#

stream_names, stream_ids = si.get_neo_streams('spikeglx', spikeglx_folder)
stream_names, stream_ids = si.get_neo_streams("spikeglx", spikeglx_folder)
stream_names

# we do not load the sync channel, so the probe is automatically loaded
raw_rec = si.read_spikeglx(spikeglx_folder, stream_name='imec0.ap', load_sync_channel=False)
raw_rec = si.read_spikeglx(spikeglx_folder, stream_name="imec0.ap", load_sync_channel=False)
raw_rec

# we automatically have the probe loaded!
@@ -63,10 +63,10 @@
#

# +
rec1 = si.highpass_filter(raw_rec, freq_min=400.)
rec1 = si.highpass_filter(raw_rec, freq_min=400.0)
bad_channel_ids, channel_labels = si.detect_bad_channels(rec1)
rec2 = rec1.remove_channels(bad_channel_ids)
print('bad_channel_ids', bad_channel_ids)
print("bad_channel_ids", bad_channel_ids)

rec3 = si.phase_shift(rec2)
rec4 = si.common_reference(rec3, operator="median", reference="global")
@@ -94,17 +94,23 @@
# here we use a static plot using matplotlib backend
fig, axs = plt.subplots(ncols=3, figsize=(20, 10))

si.plot_traces(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0])
si.plot_traces(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1])
si.plot_traces(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2])
for i, label in enumerate(('filter', 'cmr', 'final')):
si.plot_traces(rec1, backend="matplotlib", clim=(-50, 50), ax=axs[0])
si.plot_traces(rec4, backend="matplotlib", clim=(-50, 50), ax=axs[1])
si.plot_traces(rec, backend="matplotlib", clim=(-50, 50), ax=axs[2])
for i, label in enumerate(("filter", "cmr", "final")):
axs[i].set_title(label)
# -

# plot some channels
fig, ax = plt.subplots(figsize=(20, 10))
some_chans = rec.channel_ids[[100, 150, 200, ]]
si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans)
some_chans = rec.channel_ids[
    [
        100,
        150,
        200,
    ]
]
si.plot_traces({"filter": rec1, "cmr": rec4}, backend="matplotlib", mode="line", ax=ax, channel_ids=some_chans)


# ### Should we save the preprocessed data to a binary file?
@@ -118,9 +124,9 @@
# Depending on the complexity of the preprocessing chain, this operation can take a while. However, we can make use of the powerful parallelization mechanism of SpikeInterface.

# +
job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True)
job_kwargs = dict(n_jobs=40, chunk_duration="1s", progress_bar=True)

rec = rec.save(folder=base_folder / 'preprocess', format='binary', **job_kwargs)
rec = rec.save(folder=base_folder / "preprocess", format="binary", **job_kwargs)
# -

# our recording now points to the new binary folder
@@ -149,7 +155,7 @@

fig, ax = plt.subplots()
_ = ax.hist(noise_levels_microV, bins=np.arange(5, 30, 2.5))
ax.set_xlabel('noise [microV]')
ax.set_xlabel("noise [microV]")

# ### Detect and localize peaks
#
@@ -168,15 +174,16 @@
# +
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True)
peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16,
detect_threshold=5, radius_um=50., **job_kwargs)
job_kwargs = dict(n_jobs=40, chunk_duration="1s", progress_bar=True)
peaks = detect_peaks(
rec, method="locally_exclusive", noise_levels=noise_levels_int16, detect_threshold=5, radius_um=50.0, **job_kwargs
)
peaks

# +
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

peak_locations = localize_peaks(rec, peaks, method='center_of_mass', radius_um=50., **job_kwargs)
peak_locations = localize_peaks(rec, peaks, method="center_of_mass", radius_um=50.0, **job_kwargs)
# -

# ### Check for drift
@@ -190,7 +197,7 @@
# check for drift
fs = rec.sampling_frequency
fig, ax = plt.subplots(figsize=(10, 8))
ax.scatter(peaks['sample_index'] / fs, peak_locations['y'], color='k', marker='.', alpha=0.002)
ax.scatter(peaks["sample_index"] / fs, peak_locations["y"], color="k", marker=".", alpha=0.002)


# +
@@ -199,7 +206,7 @@
si.plot_probe_map(rec, ax=ax, with_channel_ids=True)
ax.set_ylim(-100, 150)

ax.scatter(peak_locations['x'], peak_locations['y'], color='purple', alpha=0.002)
ax.scatter(peak_locations["x"], peak_locations["y"], color="purple", alpha=0.002)
# -

# ## Run a spike sorter
@@ -222,18 +229,24 @@
#

# check default params for kilosort2.5
si.get_default_sorter_params('kilosort2_5')
si.get_default_sorter_params("kilosort2_5")

# +
# run kilosort2.5 without drift correction
params_kilosort2_5 = {'do_correction': False}

sorting = si.run_sorter('kilosort2_5', rec, output_folder=base_folder / 'kilosort2.5_output',
docker_image=True, verbose=True, **params_kilosort2_5)
params_kilosort2_5 = {"do_correction": False}

sorting = si.run_sorter(
"kilosort2_5",
rec,
output_folder=base_folder / "kilosort2.5_output",
docker_image=True,
verbose=True,
**params_kilosort2_5,
)
# -

# the results can be read back for future sessions
sorting = si.read_sorter_folder(base_folder / 'kilosort2.5_output')
sorting = si.read_sorter_folder(base_folder / "kilosort2.5_output")

# here we have 31 units in our recording
sorting
@@ -247,16 +260,23 @@
# Note that we use the `sparse=True` option. This option is important because the waveforms will be extracted only for a few channels around the main channel of each unit. This saves tons of disk space and speeds up the waveforms extraction and further processing.
#

we = si.extract_waveforms(rec, sorting, folder=base_folder / 'waveforms_kilosort2.5',
sparse=True, max_spikes_per_unit=500, ms_before=1.5,ms_after=2.,
**job_kwargs)
we = si.extract_waveforms(
    rec,
    sorting,
    folder=base_folder / "waveforms_kilosort2.5",
    sparse=True,
    max_spikes_per_unit=500,
    ms_before=1.5,
    ms_after=2.0,
    **job_kwargs,
)

# the `WaveformExtractor` contains all information and is persistent on disk
print(we)
print(we.folder)

# the `WaveformExtractor` can be easily loaded back from its folder
we = si.load_waveforms(base_folder / 'waveforms_kilosort2.5')
we = si.load_waveforms(base_folder / "waveforms_kilosort2.5")
we

# Many additional computations rely on the `WaveformExtractor`.
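#
# For illustration only (a sketch assuming the `si` namespace and the `job_kwargs` defined above,
# not lines taken from this commit), a few of those computations could be triggered like this:

si.compute_spike_amplitudes(we, **job_kwargs)  # per-spike amplitudes (sketch)
si.compute_unit_locations(we)  # estimated unit positions on the probe (sketch)
si.compute_correlograms(we)  # auto- and cross-correlograms (sketch)
si.compute_template_similarity(we)  # pairwise template similarity (sketch)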
@@ -281,8 +301,9 @@
#
# `si.compute_principal_components(waveform_extractor)`

metrics = si.compute_quality_metrics(we, metric_names=['firing_rate', 'presence_ratio', 'snr',
'isi_violation', 'amplitude_cutoff'])
metrics = si.compute_quality_metrics(
we, metric_names=["firing_rate", "presence_ratio", "snr", "isi_violation", "amplitude_cutoff"]
)
metrics

# ## Curation using metrics
@@ -306,16 +327,16 @@
#
# In order to export the final results we need to make a copy of the waveforms, but only for the selected units (so we can avoid computing them again).
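#
# (The cell that builds `keep_unit_ids` falls outside this hunk; a minimal sketch of such a
# selection, with purely illustrative thresholds, could be:)
keep_mask = (metrics["snr"] > 5) & (metrics["presence_ratio"] > 0.9)  # illustrative thresholds
keep_unit_ids = metrics[keep_mask].index.values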

we_clean = we.select_units(keep_unit_ids, new_folder=base_folder / 'waveforms_clean')
we_clean = we.select_units(keep_unit_ids, new_folder=base_folder / "waveforms_clean")

we_clean

# Then we export figures to a report folder

# export spike sorting report to a folder
si.export_report(we_clean, base_folder / 'report', format='png')
si.export_report(we_clean, base_folder / "report", format="png")

we_clean = si.load_waveforms(base_folder / 'waveforms_clean')
we_clean = si.load_waveforms(base_folder / "waveforms_clean")
we_clean

# And push the results to the sortingview web-based viewer
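#
# (The export cell itself falls outside this hunk; a sketch of one way to do it, assuming the
# `sortingview` backend of `plot_sorting_summary` is available and configured, could be:)
si.plot_sorting_summary(we_clean, backend="sortingview")  # generates a shareable sortingview figure (sketch)
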
61 changes: 31 additions & 30 deletions examples/how_to/get_started.py
@@ -79,7 +79,7 @@
# Then we can open it. Note that the [MEArec](<https://mearec.readthedocs.io>) simulated file
# contains both a "recording" and a "sorting" object.

local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')
local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
recording, sorting_true = se.read_mearec(local_path)
print(recording)
print(sorting_true)
@@ -103,10 +103,10 @@
num_chan = recording.get_num_channels()
num_seg = recording.get_num_segments()

print('Channel ids:', channel_ids)
print('Sampling frequency:', fs)
print('Number of channels:', num_chan)
print('Number of segments:', num_seg)
print("Channel ids:", channel_ids)
print("Sampling frequency:", fs)
print("Number of channels:", num_chan)
print("Number of segments:", num_seg)
# -

# ...and from a `BaseSorting`
@@ -116,9 +116,9 @@
unit_ids = sorting_true.get_unit_ids()
spike_train = sorting_true.get_unit_spike_train(unit_id=unit_ids[0])

print('Number of segments:', num_seg)
print('Unit ids:', unit_ids)
print('Spike train of first unit:', spike_train)
print("Number of segments:", num_seg)
print("Unit ids:", unit_ids)
print("Spike train of first unit:", spike_train)
# -

# SpikeInterface internally uses the [`ProbeInterface`](https://probeinterface.readthedocs.io/en/main/) to handle `probeinterface.Probe` and
@@ -144,29 +144,29 @@
recording_cmr = recording
recording_f = si.bandpass_filter(recording, freq_min=300, freq_max=6000)
print(recording_f)
recording_cmr = si.common_reference(recording_f, reference='global', operator='median')
recording_cmr = si.common_reference(recording_f, reference="global", operator="median")
print(recording_cmr)

# this computes and saves the recording after applying the preprocessing chain
recording_preprocessed = recording_cmr.save(format='binary')
recording_preprocessed = recording_cmr.save(format="binary")
print(recording_preprocessed)
# -

# Now you are ready to spike sort using the `spikeinterface.sorters` module!
# Let's first check which sorters are implemented and which are installed

print('Available sorters', ss.available_sorters())
print('Installed sorters', ss.installed_sorters())
print("Available sorters", ss.available_sorters())
print("Installed sorters", ss.installed_sorters())

# The `ss.installed_sorters()` will list the sorters installed on the machine.
# We can see we have HerdingSpikes and Tridesclous installed.
# Spike sorters come with a set of parameters that users can change.
# The available parameters are dictionaries and can be accessed with:

print("Tridesclous params:")
pprint(ss.get_default_sorter_params('tridesclous'))
pprint(ss.get_default_sorter_params("tridesclous"))
print("SpykingCircus2 params:")
pprint(ss.get_default_sorter_params('spykingcircus2'))
pprint(ss.get_default_sorter_params("spykingcircus2"))

# Let's run `tridesclous` and change one of the parameters, say, the `detect_threshold`:
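# (The corresponding call falls outside this hunk; a sketch of such a call, with an illustrative
# threshold value, could be:)
sorting_TDC = ss.run_sorter(
    sorter_name="tridesclous", recording=recording_preprocessed, output_folder="tdc_output", detect_threshold=4
)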

@@ -176,12 +176,13 @@
# Alternatively we can pass a full dictionary containing the parameters:

# +
other_params = ss.get_default_sorter_params('tridesclous')
other_params['detect_threshold'] = 6
other_params = ss.get_default_sorter_params("tridesclous")
other_params["detect_threshold"] = 6

# parameters set by params dictionary
sorting_TDC_2 = ss.run_sorter(sorter_name="tridesclous", recording=recording_preprocessed,
output_folder="tdc_output2", **other_params)
sorting_TDC_2 = ss.run_sorter(
sorter_name="tridesclous", recording=recording_preprocessed, output_folder="tdc_output2", **other_params
)
print(sorting_TDC_2)
# -

@@ -192,21 +193,20 @@

# The `sorting_TDC` and `sorting_SC2` are `BaseSorting` objects. We can print the units found using:

print('Units found by tridesclous:', sorting_TDC.get_unit_ids())
print('Units found by spyking-circus2:', sorting_SC2.get_unit_ids())
print("Units found by tridesclous:", sorting_TDC.get_unit_ids())
print("Units found by spyking-circus2:", sorting_SC2.get_unit_ids())

# If a sorter is not installed locally, we can still run it using a container (Docker or Singularity). For example, let's run `Kilosort2` using Docker:

sorting_KS2 = ss.run_sorter(sorter_name="kilosort2", recording=recording_preprocessed,
docker_image=True, verbose=True)
sorting_KS2 = ss.run_sorter(sorter_name="kilosort2", recording=recording_preprocessed, docker_image=True, verbose=True)
print(sorting_KS2)

# SpikeInterface provides an efficient way to extract waveforms from paired recording/sorting objects.
# The `extract_waveforms` function samples some spikes (by default `max_spikes_per_unit=500`)
# for each unit, extracts their waveforms, and stores them to disk. These waveforms are helpful to compute the average waveform, or "template", for each unit and then to compute, for example, quality metrics.

# +
we_TDC = si.extract_waveforms(recording_preprocessed, sorting_TDC, 'waveforms_folder', overwrite=True)
we_TDC = si.extract_waveforms(recording_preprocessed, sorting_TDC, "waveforms_folder", overwrite=True)
print(we_TDC)

unit_id0 = sorting_TDC.unit_ids[0]
@@ -236,7 +236,7 @@

# Importantly, waveform extractors (and all extensions) can be reloaded at later times:

we_loaded = si.load_waveforms('waveforms_folder')
we_loaded = si.load_waveforms("waveforms_folder")
print(we_loaded.get_available_extension_names())

# Once we have computed all of the postprocessing information, we can compute quality metrics (different quality metrics require different extensions - e.g., drift metrics require `spike_locations`):
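#
# (The computation itself falls outside this hunk; a sketch, assuming the quality-metrics module is
# imported under an `sqm` alias, could be:)
import spikeinterface.qualitymetrics as sqm  # alias assumed; the import is not shown in this excerpt

qm = sqm.compute_quality_metrics(we_TDC)
print(qm)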
@@ -277,21 +277,21 @@
# Alternatively, we can export the data locally to Phy. [Phy](<https://github.com/cortex-lab/phy>) is a GUI for manual
# curation of the spike sorting output. To export to phy you can run:

sexp.export_to_phy(we_TDC, 'phy_folder_for_TDC', verbose=True)
sexp.export_to_phy(we_TDC, "phy_folder_for_TDC", verbose=True)

# Then you can run the template-gui with: `phy template-gui phy_folder_for_TDC/params.py`
# and manually curate the results.

# After curating with Phy, the curated sorting can be reloaded to SpikeInterface. In this case, we exclude the units that have been labeled as "noise":

sorting_curated_phy = se.read_phy('phy_folder_for_TDC', exclude_cluster_groups=["noise"])
sorting_curated_phy = se.read_phy("phy_folder_for_TDC", exclude_cluster_groups=["noise"])

# Quality metrics can also be used to automatically curate the spike sorting
# output. For example, you can select sorted units with an SNR above a
# certain threshold:

# +
keep_mask = (qm['snr'] > 10) & (qm['isi_violations_ratio'] < 0.01)
keep_mask = (qm["snr"] > 10) & (qm["isi_violations_ratio"] < 0.01)
print("Mask:", keep_mask.values)

sorting_curated_auto = sorting_TDC.select_units(sorting_TDC.unit_ids[keep_mask])
@@ -310,8 +310,9 @@

comp_gt = sc.compare_sorter_to_ground_truth(gt_sorting=sorting_true, tested_sorting=sorting_TDC)
comp_pair = sc.compare_two_sorters(sorting1=sorting_TDC, sorting2=sorting_SC2)
comp_multi = sc.compare_multiple_sorters(sorting_list=[sorting_TDC, sorting_SC2, sorting_KS2],
name_list=['tdc', 'sc2', 'ks2'])
comp_multi = sc.compare_multiple_sorters(
    sorting_list=[sorting_TDC, sorting_SC2, sorting_KS2], name_list=["tdc", "sc2", "ks2"]
)

# When comparing with a ground-truth sorting (1), you can get the sorting performance and plot a confusion
# matrix
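#
# (That step falls outside this hunk; a sketch using the ground-truth comparison `comp_gt` and the
# widgets module `sw` from above could be:)
perf = comp_gt.get_performance()  # per-unit accuracy, recall, precision (sketch)
print(perf)
w_conf = sw.plot_confusion_matrix(comp_gt)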
@@ -335,7 +336,7 @@
# +
sorting_agreement = comp_multi.get_agreement_sorting(minimum_agreement_count=2)

print('Units in agreement between TDC, SC2, and KS2:', sorting_agreement.get_unit_ids())
print("Units in agreement between TDC, SC2, and KS2:", sorting_agreement.get_unit_ids())

w_multi = sw.plot_multicomparison_agreement(comp_multi)
w_multi = sw.plot_multicomparison_agreement_by_sorter(comp_multi)