Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kilosort4 Wrapper #2529

Merged
merged 7 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
307 changes: 307 additions & 0 deletions src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
from __future__ import annotations

from pathlib import Path
import os
from typing import Union

from ..basesorter import BaseSorter
from .kilosortbase import KilosortBase

PathType = Union[str, Path]


class Kilosort4Sorter(BaseSorter):
    """Kilosort4 Sorter object.

    Wraps the Kilosort4 Python package. The recording is exposed to Kilosort
    through ``kilosort.io.RecordingExtractorAsArray`` (no binary copy is
    written), and the pipeline stages (preprocessing, drift correction, spike
    detection, clustering, saving) are driven explicitly in
    ``_run_from_folder`` so that preprocessing and/or drift correction can be
    skipped on request.
    """

    sorter_name: str = "kilosort4"
    requires_locations = True

    # Keys that also exist in kilosort.parameters.DEFAULT_SETTINGS are
    # forwarded to Kilosort; the remaining keys control this wrapper only.
    _default_params = {
        "batch_size": 60000,
        "n_chan_bin": 385,
        "nblocks": 1,
        "Th_universal": 9,
        "Th_learned": 8,
        "do_CAR": True,
        "invert_sign": False,
        "nt": 61,
        "artifact_threshold": None,
        "nskip": 25,
        "whitening_range": 32,
        "binning_depth": 5,
        "sig_interp": 20,
        "nt0min": None,
        "dmin": None,
        "dminx": None,
        "min_template_size": 10,
        "template_sizes": 5,
        "nearest_chans": 10,
        "nearest_templates": 100,
        "templates_from_data": True,
        "n_templates": 6,
        "n_pcs": 6,
        "Th_single_ch": 6,
        "acg_threshold": 0.2,
        "ccg_threshold": 0.25,
        "cluster_downsampling": 20,
        "cluster_pcs": 64,
        "duplicate_spike_bins": 15,
        "do_correction": True,
        "keep_good_only": False,
        "save_extra_kwargs": False,
        "skip_kilosort_preprocessing": False,
        "scaleproc": None,
    }

    _params_description = {
        "batch_size": "Number of samples included in each batch of data. Default value: 60000.",
        "n_chan_bin": "Number of channels in the binary data file. Overwritten at run time with the recording's channel count. Default value: 385.",
        "nblocks": "Number of non-overlapping blocks for drift correction (additional nblocks-1 blocks are created in the overlaps). Default value: 1.",
        "Th_universal": "Spike detection threshold for universal templates. Th(1) in previous versions of Kilosort. Default value: 9.",
        "Th_learned": "Spike detection threshold for learned templates. Th(2) in previous versions of Kilosort. Default value: 8.",
        "do_CAR": "Whether to perform common average reference. Default value: True.",
        "invert_sign": "Invert the sign of the data. Default value: False.",
        "nt": "Number of samples per waveform. Also size of symmetric padding for filtering. Default value: 61.",
        "artifact_threshold": "If a batch contains absolute values above this number, it will be zeroed out under the assumption that a recording artifact is present. By default, the threshold is infinite (so that no zeroing occurs). Default value: None.",
        "nskip": "Batch stride for computing whitening matrix. Default value: 25.",
        "whitening_range": "Number of nearby channels used to estimate the whitening matrix. Default value: 32.",
        "binning_depth": "For drift correction, vertical bin size in microns used for 2D histogram. Default value: 5.",
        "sig_interp": "For drift correction, sigma for interpolation (spatial standard deviation). Approximate smoothness scale in units of microns. Default value: 20.",
        "nt0min": "Sample index for aligning waveforms, so that their minimum or maximum value happens here. Default of 20. Default value: None.",
        "dmin": "Vertical spacing of template centers used for spike detection, in microns. Determined automatically by default. Default value: None.",
        "dminx": "Horizontal spacing of template centers used for spike detection, in microns. Determined automatically by default. Default value: None.",
        "min_template_size": "Standard deviation of the smallest, spatial envelope Gaussian used for universal templates. Default value: 10.",
        "template_sizes": "Number of sizes for universal spike templates (multiples of the min_template_size). Default value: 5.",
        "nearest_chans": "Number of nearest channels to consider when finding local maxima during spike detection. Default value: 10.",
        "nearest_templates": "Number of nearest spike template locations to consider when finding local maxima during spike detection. Default value: 100.",
        "templates_from_data": "Indicates whether spike shapes used in universal templates should be estimated from the data or loaded from the predefined templates. Default value: True.",
        "n_templates": "Number of single-channel templates to use for the universal templates (only used if templates_from_data is True). Default value: 6.",
        "n_pcs": "Number of single-channel PCs to use for extracting spike features (only used if templates_from_data is True). Default value: 6.",
        "Th_single_ch": "For single channel threshold crossings to compute universal- templates. In units of whitened data standard deviations. Default value: 6.",
        "acg_threshold": 'Fraction of refractory period violations that are allowed in the ACG compared to baseline; used to assign "good" units. Default value: 0.2.',
        "ccg_threshold": "Fraction of refractory period violations that are allowed in the CCG compared to baseline; used to perform splits and merges. Default value: 0.25.",
        "cluster_downsampling": "Inverse fraction of nodes used as landmarks during clustering (can be 1, but that slows down the optimization). Default value: 20.",
        "cluster_pcs": "Maximum number of spatiotemporal PC features used for clustering. Default value: 64.",
        "duplicate_spike_bins": "Number of bins for which subsequent spikes from the same cluster are assumed to be artifacts. A value of 0 disables this step. Default value: 15.",
        "keep_good_only": "If True only 'good' units are returned",
        "do_correction": "If True, drift correction is performed",
        "save_extra_kwargs": "If True, additional kwargs are saved to the output",
        "skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing",
        "scaleproc": "int16 scaling of whitened data, if None set to 200.",
    }

    sorter_description = """Kilosort4 is a GPU-accelerated and efficient template-matching spike sorter,
    written in Python. On top of its predecessors, it implements graph-based clustering and an improved
    drift-correction strategy, and it is expected to work out of the box on Neuropixels 1.0 and 2.0 probes
    as well as other probes. For more information see https://github.com/MouseLand/Kilosort"""

    installation_mesg = """\nTo use Kilosort4 run:\n
        >>> git clone https://github.com/MouseLand/Kilosort
        >>> cd Kilosort
        >>> pip install .

    More information on Kilosort4 and its installation procedure at:
        https://github.com/MouseLand/Kilosort
    """

    # Kilosort4 operates on a single contiguous binary stream; the recording
    # is concatenated by RecordingExtractorAsArray, so multi-segment
    # recordings are not handled natively.
    handle_multi_segment = False

    @classmethod
    def is_installed(cls):
        """Return True if both ``kilosort`` and ``torch`` can be imported."""
        try:
            import kilosort as ks
            import torch  # noqa: F401 -- torch is a hard runtime dependency of kilosort

            HAVE_KS = True
        except ImportError:
            HAVE_KS = False
        return HAVE_KS

    @classmethod
    def get_sorter_version(cls):
        """Return the installed kilosort package version string."""
        import kilosort as ks

        return ks.__version__

    @classmethod
    def _setup_recording(cls, recording, sorter_output_folder, params, verbose):
        """Write the recording's probe geometry to ``probe.prb`` for Kilosort."""
        from probeinterface import write_prb

        pg = recording.get_probegroup()
        probe_filename = sorter_output_folder / "probe.prb"
        write_prb(probe_filename, pg)

    @classmethod
    def _run_from_folder(cls, sorter_output_folder, params, verbose):
        """Run the full Kilosort4 pipeline inside ``sorter_output_folder``.

        The pipeline stages from ``kilosort.run_kilosort`` are called
        individually (instead of the one-shot ``run_kilosort``) so that the
        wrapper can honor ``skip_kilosort_preprocessing`` and
        ``do_correction``.
        """
        from kilosort.run_kilosort import (
            set_files,
            initialize_ops,
            compute_preprocessing,
            compute_drift_correction,
            detect_spikes,
            cluster_spikes,
            save_sorting,
            get_run_parameters,
        )
        from kilosort.io import load_probe, RecordingExtractorAsArray, BinaryFiltered
        from kilosort.parameters import DEFAULT_SETTINGS

        import time
        import torch
        import numpy as np

        sorter_output_folder = sorter_output_folder.absolute()

        probe_filename = sorter_output_folder / "probe.prb"

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # load the recording saved by the base class and the probe written
        # by _setup_recording
        recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
        probe = load_probe(probe_filename)
        probe_name = ""
        filename = ""

        # this internally concatenates the recording
        file_object = RecordingExtractorAsArray(recording)

        do_CAR = params["do_CAR"]
        invert_sign = params["invert_sign"]
        save_extra_vars = params["save_extra_kwargs"]
        progress_bar = None
        # forward only the parameters Kilosort itself understands
        settings_ks = {k: v for k, v in params.items() if k in DEFAULT_SETTINGS}
        settings_ks["n_chan_bin"] = recording.get_num_channels()
        settings_ks["fs"] = recording.sampling_frequency
        if not do_CAR:
            print("Skipping common average reference.")

        tic0 = time.time()

        settings = {**DEFAULT_SETTINGS, **settings_ks}

        if settings["nt0min"] is None:
            # Kilosort's default alignment sample (20), scaled with the
            # waveform length nt (default 61).
            settings["nt0min"] = int(20 * settings["nt"] / 61)
        if settings["artifact_threshold"] is None:
            # infinite threshold disables artifact zeroing
            settings["artifact_threshold"] = np.inf

        # NOTE: set_files also modifies `settings` in-place
        data_dir = ""
        results_dir = sorter_output_folder
        filename, data_dir, results_dir, probe = set_files(settings, filename, probe, probe_name, data_dir, results_dir)
        ops = initialize_ops(settings, probe, recording.get_dtype(), do_CAR, invert_sign, device)

        n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, _, _, tmin, tmax, artifact = (
            get_run_parameters(ops)
        )
        # Set preprocessing and drift correction parameters
        if not params["skip_kilosort_preprocessing"]:
            ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object)
        else:
            print("Skipping kilosort preprocessing.")
            # Build the batched file reader ourselves (normally done by
            # compute_preprocessing) with no high-pass filter or whitening.
            bfile = BinaryFiltered(
                ops["filename"],
                n_chan_bin,
                fs,
                NT,
                nt,
                twav_min,
                chan_map,
                hp_filter=None,
                device=device,
                do_CAR=do_CAR,
                invert_sign=invert,
                dtype=dtype,
                tmin=tmin,
                tmax=tmax,
                artifact_threshold=artifact,
                file_object=file_object,
            )
            # identity preprocessing: no filter, identity whitening matrix
            ops["preprocessing"] = dict(hp_filter=None, whiten_mat=None)
            ops["Wrot"] = torch.as_tensor(np.eye(recording.get_num_channels()))
            ops["Nbatches"] = bfile.n_batches

        # fixed seeds for reproducible template initialization / clustering
        np.random.seed(1)
        torch.cuda.manual_seed_all(1)
        torch.random.manual_seed(1)

        if params["do_correction"]:
            # this function applies both preprocessing and drift correction
            ops, bfile, st0 = compute_drift_correction(
                ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object
            )
        else:
            print("Skipping drift correction.")
            # Re-create the file reader with the preprocessing computed above
            # (normally done inside compute_drift_correction).
            hp_filter = ops["preprocessing"]["hp_filter"]
            whiten_mat = ops["preprocessing"]["whiten_mat"]

            bfile = BinaryFiltered(
                ops["filename"],
                n_chan_bin,
                fs,
                NT,
                nt,
                twav_min,
                chan_map,
                hp_filter=hp_filter,
                whiten_mat=whiten_mat,
                device=device,
                do_CAR=do_CAR,
                invert_sign=invert,
                dtype=dtype,
                tmin=tmin,
                tmax=tmax,
                artifact_threshold=artifact,
                file_object=file_object,
            )

        # Sort spikes and save results
        st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar)
        clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar)
        if params["skip_kilosort_preprocessing"]:
            # save_sorting expects tensor entries here; substitute a zero
            # filter and an identity whitening matrix.
            ops["preprocessing"] = dict(
                hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels()))
            )
        ops, similar_templates, is_ref, est_contam_rate = save_sorting(
            ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars
        )

    @classmethod
    def _get_result_from_folder(cls, sorter_output_folder):
        """Delegate result loading (Phy output) to the shared Kilosort base."""
        return KilosortBase._get_result_from_folder(sorter_output_folder)
Loading
Loading