
Commit

Merge branch 'main' into motion_object
cwindolf authored May 31, 2024
2 parents 33e39e4 + 5715d53 commit 5df55f7
Showing 111 changed files with 1,715 additions and 763 deletions.
7 changes: 4 additions & 3 deletions pyproject.toml
@@ -95,7 +95,7 @@ full = [
"scikit-learn",
"networkx",
"distinctipy",
"matplotlib",
"matplotlib>=3.6", # matplotlib.colormaps
"cuda-python; platform_system != 'Darwin'",
"numba",
]
@@ -159,8 +159,8 @@ test = [
]

docs = [
"Sphinx==5.1.1",
"sphinx_rtd_theme==1.0.0",
"Sphinx",
"sphinx_rtd_theme",
"sphinx-gallery",
"sphinx-design",
"numpydoc",
@@ -173,6 +173,7 @@ docs = [
"hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous
"numba", # For many postprocessing functions
"xarray", # For use of SortingAnalyzer zarr format
"networkx",
# for release we need pypi, so this needs to be commented
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version
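Note: the `matplotlib>=3.6` pin above exists because the plotting code relies on the `matplotlib.colormaps` registry, which was only added in matplotlib 3.6. A minimal sketch of the difference (illustrative, not taken from this commit):

```python
import matplotlib

# matplotlib >= 3.6: colormaps are looked up through the registry
cmap = matplotlib.colormaps["viridis"]

# pre-3.6 equivalent, deprecated in newer releases:
# from matplotlib.cm import get_cmap
# cmap = get_cmap("viridis")

print(cmap(0.5))  # RGBA value at the midpoint of the colormap
```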
@@ -3,7 +3,7 @@

import pandas as pd

from spikeinterface.extractors import NumpySorting, toy_example
from spikeinterface.extractors import NumpySorting
from spikeinterface.comparison import compare_sorter_to_ground_truth


@@ -5,7 +5,8 @@
import pytest
import numpy as np

from spikeinterface.extractors import NumpySorting, toy_example
from spikeinterface.core import generate_sorting
from spikeinterface.extractors import NumpySorting
from spikeinterface.comparison import compare_multiple_sorters, MultiSortingComparison

if hasattr(pytest, "global_test_folder"):
@@ -72,7 +73,7 @@ def test_compare_multiple_sorters():

def test_compare_multi_segment():
num_segments = 3
_, sort = toy_example(num_segments=num_segments)
sort = generate_sorting(durations=[10] * num_segments)

cmp_multi = compare_multiple_sorters([sort, sort, sort])

@@ -1,6 +1,7 @@
import numpy as np

from spikeinterface.extractors import NumpySorting, toy_example
from spikeinterface.core import generate_sorting
from spikeinterface.extractors import NumpySorting
from spikeinterface.comparison import compare_two_sorters


@@ -29,7 +30,7 @@ def test_compare_two_sorters():


def test_compare_multi_segment():
_, sort = toy_example(num_segments=2)
sort = generate_sorting(durations=[10, 10])

cmp_multi = compare_two_sorters(sort, sort)

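Note: the comparison tests above swap `toy_example` for `generate_sorting` from `spikeinterface.core`, which builds a synthetic multi-segment sorting without also creating a recording. A hedged usage sketch (argument values are illustrative):

```python
from spikeinterface.core import generate_sorting
from spikeinterface.comparison import compare_two_sorters

# one synthetic sorting with two 10-second segments, as in the updated test
sort = generate_sorting(durations=[10, 10])
cmp = compare_two_sorters(sort, sort)  # comparing a sorting with itself should match perfectly
```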
@@ -3,8 +3,7 @@
from pathlib import Path
import numpy as np

from spikeinterface.core import create_sorting_analyzer
from spikeinterface.extractors import toy_example
from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording
from spikeinterface.comparison import compare_templates, compare_multiple_templates


@@ -27,9 +26,7 @@ def test_compare_multiple_templates():
duration = 60
num_channels = 8

rec, sort = toy_example(duration=duration, num_segments=1, num_channels=num_channels)
# rec = rec.save(folder=test_dir / "rec")
# sort = sort.save(folder=test_dir / "sort")
rec, sort = generate_ground_truth_recording(durations=[duration], num_channels=num_channels)

# split recording in 3 equal slices
fs = rec.get_sampling_frequency()
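Note: the template-comparison test now uses `generate_ground_truth_recording`, which returns a `(recording, sorting)` pair and so drops in where `toy_example` used to be. A hedged sketch of the call as used above:

```python
from spikeinterface.core import generate_ground_truth_recording

# 60 s single-segment recording with 8 channels, matching the updated test
rec, sort = generate_ground_truth_recording(durations=[60], num_channels=8)
fs = rec.get_sampling_frequency()
print(rec.get_num_channels(), fs)
```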
32 changes: 21 additions & 11 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -49,9 +49,8 @@ class ComputeRandomSpikes(AnalyzerExtension):
use_nodepipeline = False
need_job_kwargs = False

def _run(
self,
):
def _run(self, verbose=False):

self.data["random_spikes_indices"] = random_spikes_selection(
self.sorting_analyzer.sorting,
num_samples=self.sorting_analyzer.rec_attributes["num_samples"],
@@ -145,7 +144,7 @@ def nbefore(self):
def nafter(self):
return int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)

def _run(self, **job_kwargs):
def _run(self, verbose=False, **job_kwargs):
self.data.clear()

recording = self.sorting_analyzer.recording
@@ -183,6 +182,7 @@ def _run(self, **job_kwargs):
sparsity_mask=sparsity_mask,
copy=copy,
job_name="compute_waveforms",
verbose=verbose,
**job_kwargs,
)

@@ -311,7 +311,7 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N
)
return params

def _run(self, **job_kwargs):
def _run(self, verbose=False, **job_kwargs):
self.data.clear()

if self.sorting_analyzer.has_extension("waveforms"):
@@ -330,17 +330,27 @@ def _run(self, **job_kwargs):

return_scaled = self.sorting_analyzer.return_scaled

self.data["average"], self.data["std"] = estimate_templates_with_accumulator(
return_std = "std" in self.params["operators"]
output = estimate_templates_with_accumulator(
recording,
some_spikes,
unit_ids,
self.nbefore,
self.nafter,
return_scaled=return_scaled,
return_std=True,
return_std=return_std,
verbose=verbose,
**job_kwargs,
)

# Output of estimate_templates_with_accumulator is either (templates,) or (templates, stds)
if return_std:
templates, stds = output
self.data["average"] = templates
self.data["std"] = stds
else:
self.data["average"] = output

def _compute_and_append_from_waveforms(self, operators):
if not self.sorting_analyzer.has_extension("waveforms"):
raise ValueError(f"Computing templates with operators {operators} needs the 'waveforms' extension")
@@ -479,9 +489,9 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
self.params["operators"] += [(operator, percentile)]
templates_array = self.data[key]

if save:
if not self.sorting_analyzer.is_read_only():
self.save()
if save:
if not self.sorting_analyzer.is_read_only():
self.save()

if unit_ids is not None:
unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids)
@@ -572,7 +582,7 @@ def _select_extension_data(self, unit_ids):
# this do not depend on units
return self.data

def _run(self):
def _run(self, verbose=False):
self.data["noise_levels"] = get_noise_levels(
self.sorting_analyzer.recording, return_scaled=self.sorting_analyzer.return_scaled, **self.params
)
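Note: two things change in `analyzer_extension_core.py`: every `_run` now accepts a `verbose` flag that is forwarded to the underlying job, and the templates extension only requests `std` from `estimate_templates_with_accumulator` when the `"std"` operator was asked for, so its return value must be unpacked conditionally. A stand-alone sketch of that unpacking pattern (the helper below is a stand-in, not the real accumulator function):

```python
import numpy as np

def fake_estimate_templates(return_std: bool):
    # stand-in for estimate_templates_with_accumulator: returns either the
    # templates array alone, or a (templates, stds) tuple
    templates = np.zeros((3, 20, 4))  # (num_units, num_samples, num_channels), made-up shape
    if return_std:
        return templates, np.ones_like(templates)
    return templates

operators = ["average", "std"]
return_std = "std" in operators
output = fake_estimate_templates(return_std)

if return_std:
    templates, stds = output
else:
    templates, stds = output, None
```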
71 changes: 45 additions & 26 deletions src/spikeinterface/core/base.py
@@ -846,43 +846,63 @@ def save_to_memory(self, sharedmem=True, **save_kwargs) -> "BaseExtractor":
return cached

# TODO rename to saveto_binary_folder
def save_to_folder(self, name=None, folder=None, overwrite=False, verbose=True, **save_kwargs):
def save_to_folder(
self,
name: str | None = None,
folder: str | Path | None = None,
overwrite: bool = False,
verbose: bool = True,
**save_kwargs,
):
"""
Save extractor to folder.
Save the extractor and its data to a folder.
The save consist of:
* extracting traces by calling get_trace() method in chunks
* saving data into file (memmap with BinaryRecordingExtractor)
* dumping to json/pickle the original extractor for provenance
* dumping to json/pickle the cached extractor (memmap with BinaryRecordingExtractor)
This method extracts trace data, saves it to a file (using a memory-mapped approach),
and stores both the original extractor's provenance
and the extractor's metadata in JSON format.
This replaces the use of the old CacheRecordingExtractor and CacheSortingExtractor.
The folder's final location and name can be specified in a couple of ways:
There are 2 option for the "folder" argument:
* explicit folder: `extractor.save(folder="/path-for-saving/")`
* explicit sub-folder, implicit base-folder : `extractor.save(name="extarctor_name")`
* generated: `extractor.save()`
1. Explicitly providing the full path:
```
extractor.save_to_folder(folder="/path/to/save/")
```
The second option saves to subfolder "extractor_name" in
"get_global_tmp_folder()". You can set the global tmp folder with:
"set_global_tmp_folder("path-to-global-folder")"
2. Providing a subfolder name, with the base folder being determined automatically:
```
extractor.save_to_folder(name="my_extractor_data")
```
In this case, the data is saved in a subfolder named "my_extractor_data"
within the global temporary folder (set using `set_global_tmp_folder`). If no
global temporary folder is set, one will be generated automatically.
The folder must not exist. If it exists, remove it before.
3. If neither `name` nor `folder` is provided, a random name will be generated
for the subfolder within the global temporary folder.
Parameters
----------
name: None str or Path
Name of the subfolder in get_global_tmp_folder()
If "name" is given, "folder" must be None.
folder: None str or Path
Name of the folder.
If "folder" is given, "name" must be None.
overwrite: bool, default: False
If True, the folder is removed if it already exists
name : str , optional
The name of the subfolder within the global temporary folder. If `folder`
is provided, this argument must be None.
folder : str or Path, optional
The full path of the folder where the data should be saved. If `name` is
provided, this argument must be None.
overwrite : bool, default: False
If True, an existing folder at the specified path will be deleted before saving.
verbose : bool, default: True
If True, print information about the cache folder being used.
**save_kwargs
Additional keyword arguments to be passed to the underlying save method.
Returns
-------
cached: saved copy of the extractor.
cached_extractor
A saved copy of the extractor in the specified format.
Raises
------
AssertionError
If the folder already exists and `overwrite` is False.
"""

if folder is None:
@@ -925,7 +945,6 @@ def save_to_folder(self, name=None, folder=None, overwrite=False, verbose=True,
self.copy_metadata(cached)

# dump
# cached.dump(folder / f'cached.json', relative_to=folder, folder_metadata=folder)
cached.dump(folder / f"si_folder.json", relative_to=folder)

return cached
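Note: based on the rewritten `save_to_folder` docstring, the two explicit ways of choosing the destination look roughly like this (paths and names are placeholders, and `generate_recording` is only used here to have something to save):

```python
from spikeinterface.core import generate_recording, set_global_tmp_folder

rec = generate_recording(durations=[2.0])

# Option 1: explicit destination folder
cached = rec.save_to_folder(folder="/tmp/my_recording", overwrite=True)

# Option 2: subfolder name inside the global temporary folder
set_global_tmp_folder("/tmp/si_tmp")
cached2 = rec.save_to_folder(name="my_recording_copy")
```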
16 changes: 13 additions & 3 deletions src/spikeinterface/core/baserecording.py
@@ -381,6 +381,11 @@ def has_scaled_traces(self) -> bool:
bool
True if the recording has scaled traces, False otherwise
"""
warnings.warn(
"`has_scaled_traces` is deprecated and will be removed in 0.103.0. Use has_scaleable_traces() instead",
category=DeprecationWarning,
stacklevel=2,
)
return self.has_scaled()

def get_time_info(self, segment_index=None) -> dict:
@@ -491,7 +496,7 @@ def time_to_sample_index(self, time_s, segment_index=None):
rs = self._recording_segments[segment_index]
return rs.time_to_sample_index(time_s)

def _save(self, format="binary", **save_kwargs):
def _save(self, format="binary", verbose: bool = False, **save_kwargs):
# handle t_starts
t_starts = []
has_time_vectors = []
@@ -510,7 +515,7 @@ def _save(self, format="binary", **save_kwargs):
file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())]
dtype = kwargs.get("dtype", None) or self.get_dtype()

write_binary_recording(self, file_paths=file_paths, dtype=dtype, **job_kwargs)
write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs)

from .binaryrecordingextractor import BinaryRecordingExtractor

@@ -540,14 +545,18 @@ def _save(self, format="binary", **save_kwargs):

cached = SharedMemoryRecording.from_recording(self, **job_kwargs)
else:
from spikeinterface.core import NumpyRecording

cached = NumpyRecording.from_recording(self, **job_kwargs)

elif format == "zarr":
from .zarrextractors import ZarrRecordingExtractor

zarr_path = kwargs.pop("zarr_path")
storage_options = kwargs.pop("storage_options")
ZarrRecordingExtractor.write_recording(self, zarr_path, storage_options, **kwargs, **job_kwargs)
ZarrRecordingExtractor.write_recording(
self, zarr_path, storage_options, verbose=verbose, **kwargs, **job_kwargs
)
cached = ZarrRecordingExtractor(zarr_path, storage_options)

elif format == "nwb":
@@ -636,6 +645,7 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None):
warnings.warn(
"This method will be removed in version 0.103, use `select_channels` or `rename_channels` instead.",
DeprecationWarning,
stacklevel=2,
)
sub_recording = ChannelSliceRecording(self, channel_ids, renamed_channel_ids=renamed_channel_ids)
return sub_recording
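Note: the deprecations above all follow the same pattern: warn with `stacklevel=2` so the message points at the caller rather than the wrapper, then delegate to the replacement method. A generic sketch (class and method names are illustrative):

```python
import warnings

class Legacy:
    def new_api(self):
        return 42

    def old_api(self):
        warnings.warn(
            "`old_api` is deprecated and will be removed in a future release. Use `new_api()` instead",
            category=DeprecationWarning,
            stacklevel=2,  # attribute the warning to the code that called old_api()
        )
        return self.new_api()

Legacy().old_api()
```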
10 changes: 9 additions & 1 deletion src/spikeinterface/core/baserecordingsnippets.py
@@ -48,12 +48,20 @@ def get_num_channels(self):
def get_dtype(self):
return self._dtype

def has_scaled(self):
def has_scaleable_traces(self):
if self.get_property("gain_to_uV") is None or self.get_property("offset_to_uV") is None:
return False
else:
return True

def has_scaled(self):
warn(
"`has_scaled` has been deprecated and will be removed in 0.103.0. Please use `has_scaleable_traces()`",
category=DeprecationWarning,
stacklevel=2,
)
return self.has_scaleable_traces()

def has_probe(self):
return "contact_vector" in self.get_property_keys()

9 changes: 7 additions & 2 deletions src/spikeinterface/core/basesnippets.py
@@ -81,7 +81,12 @@ def get_num_segments(self):
return len(self._snippets_segments)

def has_scaled_snippets(self):
return self.has_scaled()
warn(
"`has_scaled_snippets` is deprecated and will be removed in version 0.103.0. Please use `has_scaleable_traces()` instead",
category=DeprecationWarning,
stacklevel=2,
)
return self.has_scaleable_traces()

def get_frames(self, indices=None, segment_index: Union[int, None] = None):
segment_index = self._check_segment_index(segment_index)
@@ -101,7 +106,7 @@ def get_snippets(
wfs = spts.get_snippets(indices, channel_indices=channel_indices)

if return_scaled:
if not self.has_scaled():
if not self.has_scaleable_traces():
raise ValueError(
"These snippets do not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)"
)
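Note: for context on `has_scaleable_traces` and `return_scaled`, scaling is only possible when per-channel `gain_to_uV` and `offset_to_uV` properties are present, and the conversion itself is a linear transform. A hedged numeric sketch (values are made up):

```python
import numpy as np

raw = np.array([[100, 200], [300, 400]], dtype="int16")  # (num_samples, num_channels)
gain_to_uV = np.array([0.195, 0.195])                     # illustrative per-channel gains
offset_to_uV = np.array([0.0, 0.0])

# has_scaleable_traces() is True only when both properties are set;
# the scaled traces are then obtained as raw * gain + offset, in microvolts.
scaled = raw.astype("float32") * gain_to_uV + offset_to_uV
print(scaled)
```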
