Skip to content

Commit

Permalink
Prepare release 0.100.6
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Apr 29, 2024
1 parent 449e219 commit a5cb02f
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 67 deletions.
12 changes: 12 additions & 0 deletions doc/releases/0.100.6.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
.. _release0.100.6:

SpikeInterface 0.100.6 release notes
------------------------------------

30th April 2024

Minor release with bug fixes

* Improve caching of MS5 sorter (#2690)
* Allow for remove_excess_spikes to remove negative spike times (#2716)
* Update ks4 wrapper for newer version>=4.0.3 (#2701, #2774)
1 change: 1 addition & 0 deletions doc/whatisnew.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Release notes
.. toctree::
:maxdepth: 1

releases/0.100.6.rst
releases/0.100.5.rst
releases/0.100.4.rst
releases/0.100.3.rst
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "spikeinterface"
version = "0.100.5"
version = "0.100.6"
authors = [
{ name="Alessio Buccino", email="[email protected]" },
{ name="Samuel Garcia", email="[email protected]" },
Expand Down
81 changes: 15 additions & 66 deletions src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from pathlib import Path
import os
from typing import Union
from packaging.version import parse

from ..basesorter import BaseSorter
from .kilosortbase import KilosortBase
Expand Down Expand Up @@ -51,10 +51,11 @@ class Kilosort4Sorter(BaseSorter):
"save_extra_kwargs": False,
"skip_kilosort_preprocessing": False,
"scaleproc": None,
"torch_device": "auto",
}

_params_description = {
"batch_size": "Number of samples per batch. Default value: 60000.",
"batch_size": "Number of samples included in each batch of data.",
"nblocks": "Number of non-overlapping blocks for drift correction (additional nblocks-1 blocks are created in the overlaps). Default value: 1.",
"Th_universal": "Spike detection threshold for universal templates. Th(1) in previous versions of Kilosort. Default value: 9.",
"Th_learned": "Spike detection threshold for learned templates. Th(2) in previous versions of Kilosort. Default value: 8.",
Expand Down Expand Up @@ -87,6 +88,7 @@ class Kilosort4Sorter(BaseSorter):
"save_extra_kwargs": "If True, additional kwargs are saved to the output",
"skip_kilosort_preprocessing": "Can optionally skip the internal kilosort preprocessing",
"scaleproc": "int16 scaling of whitened data, if None set to 200.",
"torch_device": "Select the torch device auto/cuda/cpu",
}

sorter_description = """Kilosort4 is a Python package for spike sorting on GPUs with template matching.
Expand Down Expand Up @@ -152,7 +154,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

probe_filename = sorter_output_folder / "probe.prb"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_device = params["torch_device"]
if torch_device == "auto":
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(torch_device)

# load probe
recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
Expand Down Expand Up @@ -222,39 +227,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
torch.cuda.manual_seed_all(1)
torch.random.manual_seed(1)
# if not params["skip_kilosort_preprocessing"]:
if params["do_correction"]:
# this function applies both preprocessing and drift correction
ops, bfile, st0 = compute_drift_correction(
ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object
)
else:
if not params["do_correction"]:
print("Skipping drift correction.")
hp_filter = ops["preprocessing"]["hp_filter"]
whiten_mat = ops["preprocessing"]["whiten_mat"]

bfile = BinaryFiltered(
ops["filename"],
n_chan_bin,
fs,
NT,
nt,
twav_min,
chan_map,
hp_filter=hp_filter,
whiten_mat=whiten_mat,
device=device,
do_CAR=do_CAR,
invert_sign=invert,
dtype=dtype,
tmin=tmin,
tmax=tmax,
artifact_threshold=artifact,
file_object=file_object,
)
ops["nblocks"] = 0

# TODO: don't think we need to do this actually
# Save intermediate `ops` for use by GUI plots
# io.save_ops(ops, results_dir)
# this function applies both preprocessing and drift correction
ops, bfile, st0 = compute_drift_correction(
ops, device, tic0=tic0, progress_bar=progress_bar, file_object=file_object
)

# Sort spikes and save results
st, tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar)
Expand All @@ -263,39 +243,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
ops["preprocessing"] = dict(
hp_filter=torch.as_tensor(np.zeros(1)), whiten_mat=torch.as_tensor(np.eye(recording.get_num_channels()))
)
ops, similar_templates, is_ref, est_contam_rate = save_sorting(
ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars
)

# # Clean-up temporary files
# if params["delete_recording_dat"] and (recording_file := sorter_output_folder / "recording.dat").exists():
# recording_file.unlink()

# all_tmp_files = ("matlab_files", "temp_wh.dat")

# if isinstance(params["delete_tmp_files"], bool):
# if params["delete_tmp_files"]:
# tmp_files_to_remove = all_tmp_files
# else:
# tmp_files_to_remove = ()
# else:
# assert isinstance(
# params["delete_tmp_files"], (tuple, list)
# ), "`delete_tmp_files` must be a `Bool`, `Tuple` or `List`."

# for name in params["delete_tmp_files"]:
# assert name in all_tmp_files, f"{name} is not a valid option, must be one of: {all_tmp_files}"

# tmp_files_to_remove = params["delete_tmp_files"]

# if "temp_wh.dat" in tmp_files_to_remove:
# if (temp_wh_file := sorter_output_folder / "temp_wh.dat").exists():
# temp_wh_file.unlink()

# if "matlab_files" in tmp_files_to_remove:
# for ext in ["*.m", "*.mat"]:
# for temp_file in sorter_output_folder.glob(ext):
# temp_file.unlink()
_ = save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0, save_extra_vars=save_extra_vars)

@classmethod
def _get_result_from_folder(cls, sorter_output_folder):
Expand Down

0 comments on commit a5cb02f

Please sign in to comment.