Merge branch 'SpikeInterface:main' into regularize_whitening
yger authored May 25, 2024
2 parents 842aaca + 790715c commit d693d13
Showing 18 changed files with 134 additions and 72 deletions.
13 changes: 7 additions & 6 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -49,9 +49,8 @@ class ComputeRandomSpikes(AnalyzerExtension):
use_nodepipeline = False
need_job_kwargs = False

def _run(
self,
):
def _run(self, verbose=False):

self.data["random_spikes_indices"] = random_spikes_selection(
self.sorting_analyzer.sorting,
num_samples=self.sorting_analyzer.rec_attributes["num_samples"],
@@ -145,7 +144,7 @@ def nbefore(self):
def nafter(self):
return int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)

def _run(self, **job_kwargs):
def _run(self, verbose=False, **job_kwargs):
self.data.clear()

recording = self.sorting_analyzer.recording
@@ -183,6 +182,7 @@ def _run(self, **job_kwargs):
sparsity_mask=sparsity_mask,
copy=copy,
job_name="compute_waveforms",
verbose=verbose,
**job_kwargs,
)

@@ -311,7 +311,7 @@ def _set_params(self, ms_before: float = 1.0, ms_after: float = 2.0, operators=N
)
return params

def _run(self, **job_kwargs):
def _run(self, verbose=False, **job_kwargs):
self.data.clear()

if self.sorting_analyzer.has_extension("waveforms"):
@@ -339,6 +339,7 @@ def _run(self, **job_kwargs):
self.nafter,
return_scaled=return_scaled,
return_std=return_std,
verbose=verbose,
**job_kwargs,
)

@@ -581,7 +582,7 @@ def _select_extension_data(self, unit_ids):
# this does not depend on units
return self.data

def _run(self):
def _run(self, verbose=False):
self.data["noise_levels"] = get_noise_levels(
self.sorting_analyzer.recording, return_scaled=self.sorting_analyzer.return_scaled, **self.params
)
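
Taken together, the hunks above converge on a single `_run` convention for extensions. The following is a minimal sketch of the two variants, not code from this commit: the class names and the helper call are illustrative placeholders, only the signatures mirror the diff.

# Minimal sketch of the _run convention this commit applies (illustrative classes).

class SimpleExtension:
    """Extensions that need no chunked job (e.g. noise levels) simply accept verbose."""

    def _run(self, verbose=False):
        # compute directly; no job_kwargs are involved
        ...


class ChunkedExtension:
    """Extensions backed by a ChunkRecordingExecutor forward verbose alongside the job kwargs."""

    def _run(self, verbose=False, **job_kwargs):
        # the underlying helper (e.g. extract_waveforms_to_single_buffer) now receives
        # verbose explicitly: some_helper(..., verbose=verbose, **job_kwargs)
        ...
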
2 changes: 2 additions & 0 deletions src/spikeinterface/core/node_pipeline.py
@@ -473,6 +473,7 @@ def run_node_pipeline(
squeeze_output=True,
folder=None,
names=None,
verbose=False,
):
"""
Common function to run a pipeline with a peak detector or with already-detected peaks.
@@ -499,6 +500,7 @@
init_args,
gather_func=gather_func,
job_name=job_name,
verbose=verbose,
**job_kwargs,
)

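
For context, here is a hedged sketch of how a caller passes the new `verbose` flag through `run_node_pipeline`. The `recording` and `nodes` objects are assumed and not constructed here; the keyword names follow the call sites changed later in this commit.

# Hedged sketch: forwarding verbose through run_node_pipeline (recording/nodes assumed).
from spikeinterface.core.node_pipeline import run_node_pipeline

results = run_node_pipeline(
    recording,                 # assumed BaseRecording
    nodes,                     # assumed list of pipeline nodes
    job_kwargs={"n_jobs": 4},  # regular chunking/parallelism options
    job_name="my_pipeline",
    gather_mode="memory",
    verbose=True,              # new: forwarded to the underlying ChunkRecordingExecutor
)

The design choice here is that `verbose` is a plain keyword of the pipeline function rather than something buried in `job_kwargs`, so it can be forwarded unchanged to the executor.
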
28 changes: 16 additions & 12 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -835,7 +835,7 @@ def get_num_units(self) -> int:
return self.sorting.get_num_units()

## extensions zone
def compute(self, input, save=True, extension_params=None, **kwargs):
def compute(self, input, save=True, extension_params=None, verbose=False, **kwargs):
"""
Compute one extension or several extensions.
Internally calls compute_one_extension() or compute_several_extensions() depending on the input type.
Expand Down Expand Up @@ -883,11 +883,11 @@ def compute(self, input, save=True, extension_params=None, **kwargs):
)
"""
if isinstance(input, str):
return self.compute_one_extension(extension_name=input, save=save, **kwargs)
return self.compute_one_extension(extension_name=input, save=save, verbose=verbose, **kwargs)
elif isinstance(input, dict):
params_, job_kwargs = split_job_kwargs(kwargs)
assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()"
self.compute_several_extensions(extensions=input, save=save, **job_kwargs)
self.compute_several_extensions(extensions=input, save=save, verbose=verbose, **job_kwargs)
elif isinstance(input, list):
params_, job_kwargs = split_job_kwargs(kwargs)
assert len(params_) == 0, "Too many arguments for SortingAnalyzer.compute_several_extensions()"
@@ -898,11 +898,11 @@ def compute(self, input, save=True, extension_params=None, **kwargs):
ext_name in input
), f"SortingAnalyzer.compute(): Parameters specified for {ext_name}, which is not in the specified {input}"
extensions[ext_name] = ext_params
self.compute_several_extensions(extensions=extensions, save=save, **job_kwargs)
self.compute_several_extensions(extensions=extensions, save=save, verbose=verbose, **job_kwargs)
else:
raise ValueError("SortingAnalyzer.compute() need str, dict or list")

def compute_one_extension(self, extension_name, save=True, **kwargs):
def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs):
"""
Compute one extension.
@@ -925,7 +925,7 @@ def compute_one_extension(self, extension_name, save=True, **kwargs):
Returns
-------
result_extension: AnalyzerExtension
Return the extension instance.
Return the extension instance
Examples
--------
@@ -961,13 +961,16 @@ def compute_one_extension(self, extension_name, save=True, **kwargs):

extension_instance = extension_class(self)
extension_instance.set_params(save=save, **params)
extension_instance.run(save=save, **job_kwargs)
if extension_class.need_job_kwargs:
extension_instance.run(save=save, verbose=verbose, **job_kwargs)
else:
extension_instance.run(save=save, verbose=verbose)

self.extensions[extension_name] = extension_instance

return extension_instance

def compute_several_extensions(self, extensions, save=True, **job_kwargs):
def compute_several_extensions(self, extensions, save=True, verbose=False, **job_kwargs):
"""
Compute several extensions
@@ -1021,9 +1024,9 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
for extension_name, extension_params in extensions_without_pipeline.items():
extension_class = get_extension_class(extension_name)
if extension_class.need_job_kwargs:
self.compute_one_extension(extension_name, save=save, **extension_params, **job_kwargs)
self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params, **job_kwargs)
else:
self.compute_one_extension(extension_name, save=save, **extension_params)
self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params)
# then extensions with pipeline
if len(extensions_with_pipeline) > 0:
all_nodes = []
@@ -1053,6 +1056,7 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
job_name=job_name,
gather_mode="memory",
squeeze_output=False,
verbose=verbose,
)

for r, result in enumerate(results):
@@ -1071,9 +1075,9 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
for extension_name, extension_params in extensions_post_pipeline.items():
extension_class = get_extension_class(extension_name)
if extension_class.need_job_kwargs:
self.compute_one_extension(extension_name, save=save, **extension_params, **job_kwargs)
self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params, **job_kwargs)
else:
self.compute_one_extension(extension_name, save=save, **extension_params)
self.compute_one_extension(extension_name, save=save, verbose=verbose, **extension_params)

def get_saved_extension_names(self):
"""
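
At the API surface, the flag now flows from `SortingAnalyzer.compute()` down to each extension. The following usage sketch is illustrative: `sorting_analyzer` is an assumed, already-created SortingAnalyzer, and the extension names and parameters are examples rather than content of this diff.

# Hedged usage sketch: verbose reaches compute_one_extension / compute_several_extensions.

# str input -> compute_one_extension(..., verbose=True)
sorting_analyzer.compute("random_spikes", verbose=True)

# dict input -> compute_several_extensions(..., verbose=True); extra kwargs remain job kwargs
sorting_analyzer.compute(
    {"waveforms": {"ms_before": 1.0, "ms_after": 2.0}, "noise_levels": {}},
    verbose=True,
    n_jobs=4,
)
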
15 changes: 12 additions & 3 deletions src/spikeinterface/core/waveform_tools.py
@@ -221,6 +221,7 @@ def distribute_waveforms_to_buffers(
mode="memmap",
sparsity_mask=None,
job_name=None,
verbose=False,
**job_kwargs,
):
"""
@@ -281,7 +282,9 @@
)
if job_name is None:
job_name = f"extract waveforms {mode} multi buffer"
processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs)
processor = ChunkRecordingExecutor(
recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs
)
processor.run()


@@ -410,6 +413,7 @@ def extract_waveforms_to_single_buffer(
sparsity_mask=None,
copy=True,
job_name=None,
verbose=False,
**job_kwargs,
):
"""
@@ -523,7 +527,9 @@
if job_name is None:
job_name = f"extract waveforms {mode} mono buffer"

processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs)
processor = ChunkRecordingExecutor(
recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs
)
processor.run()

if mode == "memmap":
@@ -783,6 +789,7 @@ def estimate_templates_with_accumulator(
return_scaled: bool = True,
job_name=None,
return_std: bool = False,
verbose: bool = False,
**job_kwargs,
):
"""
@@ -861,7 +868,9 @@

if job_name is None:
job_name = "estimate_templates_with_accumulator"
processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs)
processor = ChunkRecordingExecutor(
recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs
)
processor.run()

# average
17 changes: 12 additions & 5 deletions src/spikeinterface/extractors/nwbextractors.py
@@ -440,7 +440,8 @@ class NwbRecordingExtractor(BaseRecording):
stream_cache_path: str, Path, or None, default: None
Specifies the local path for caching the file. Relevant only if `cache` is True.
storage_options: dict | None = None,
Additional parameters for the storage backend (e.g. AWS credentials) used for "zarr" stream_mode.
These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function.
This is only used on the "zarr" stream_mode.
use_pynwb: bool, default: False
Uses the pynwb library to read the NWB file. Setting this to False, the default, uses h5py
to read the file. Using h5py can improve performance by bypassing some of the PyNWB validations.
@@ -861,8 +862,10 @@ def _fetch_main_properties_backend(self):

@staticmethod
def fetch_available_electrical_series_paths(
file_path: str | Path, stream_mode: Optional[Literal["fsspec", "remfile", "zarr"]] = None
) -> List[str]:
file_path: str | Path,
stream_mode: Optional[Literal["fsspec", "remfile", "zarr"]] = None,
storage_options: dict | None = None,
) -> list[str]:
"""
Retrieves the paths to all ElectricalSeries objects within a neurodata file.
@@ -873,7 +876,9 @@ def fetch_available_electrical_series_paths(
stream_mode : "fsspec" | "remfile" | "zarr" | None, optional
Determines the streaming mode for reading the file. Use this for optimized reading from
different sources, such as local disk or remote servers.
storage_options: dict | None = None,
These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function.
This is only used on the "zarr" stream_mode.
Returns
-------
list of str
@@ -901,6 +906,7 @@ def fetch_available_electrical_series_paths(
file_handle = read_file_from_backend(
file_path=file_path,
stream_mode=stream_mode,
storage_options=storage_options,
)

electrical_series_paths = _find_neurodata_type_from_backend(
@@ -988,7 +994,8 @@ class NwbSortingExtractor(BaseSorting):
If True, the file is cached in the file passed to stream_cache_path
if False, the file is not cached.
storage_options: dict | None = None,
Additional parameters for the storage backend (e.g. AWS credentials) used for "zarr" stream_mode.
These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function.
This is only used on the "zarr" stream_mode.
use_pynwb: bool, default: False
Uses the pynwb library to read the NWB file. Setting this to False, the default, uses h5py
to read the file. Using h5py can improve performance by bypassing some of the PyNWB validations.
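
A hedged sketch of the new `storage_options` pass-through on the static helper follows; the bucket URL and credentials are placeholders, not values from this diff.

# Hedged sketch: storage_options is forwarded to the zarr backend (placeholder values).
from spikeinterface.extractors import NwbRecordingExtractor

electrical_series_paths = NwbRecordingExtractor.fetch_available_electrical_series_paths(
    file_path="s3://example-bucket/session.nwb.zarr",  # assumed remote zarr store
    stream_mode="zarr",
    storage_options={"anon": False},  # e.g. AWS credentials, passed through to zarr.open
)
print(electrical_series_paths)
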
3 changes: 2 additions & 1 deletion src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -181,7 +181,7 @@ def _get_pipeline_nodes(self):
nodes = [spike_retriever_node, amplitude_scalings_node]
return nodes

def _run(self, **job_kwargs):
def _run(self, verbose=False, **job_kwargs):
job_kwargs = fix_job_kwargs(job_kwargs)
nodes = self.get_pipeline_nodes()
amp_scalings, collision_mask = run_node_pipeline(
@@ -190,6 +190,7 @@ def _run(self, **job_kwargs):
job_kwargs=job_kwargs,
job_name="amplitude_scalings",
gather_mode="memory",
verbose=verbose,
)
self.data["amplitude_scalings"] = amp_scalings
if self.params["handle_collisions"]:
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/correlograms.py
@@ -70,7 +70,7 @@ def _select_extension_data(self, unit_ids):
new_data = dict(ccgs=new_ccgs, bins=new_bins)
return new_data

def _run(self):
def _run(self, verbose=False):
ccgs, bins = compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params)
self.data["ccgs"] = ccgs
self.data["bins"] = bins
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/isi.py
@@ -56,7 +56,7 @@ def _select_extension_data(self, unit_ids):
new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins)
return new_extension_data

def _run(self):
def _run(self, verbose=False):
isi_histograms, bins = _compute_isi_histograms(self.sorting_analyzer.sorting, **self.params)
self.data["isi_histograms"] = isi_histograms
self.data["bins"] = bins
8 changes: 5 additions & 3 deletions src/spikeinterface/postprocessing/principal_component.py
@@ -256,7 +256,7 @@ def project_new(self, new_spikes, new_waveforms, progress_bar=True):
new_projections = self._transform_waveforms(new_spikes, new_waveforms, pca_model, progress_bar=progress_bar)
return new_projections

def _run(self, **job_kwargs):
def _run(self, verbose=False, **job_kwargs):
"""
Compute the PCs on the waveforms extracted by ComputeWaveforms.
Projections are computed only on the waveforms sampled by the SortingAnalyzer.
@@ -295,7 +295,7 @@ def _run(self, **job_kwargs):
def _get_data(self):
return self.data["pca_projection"]

def run_for_all_spikes(self, file_path=None, **job_kwargs):
def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):
"""
Project all spikes from the sorting on the PCA model.
This is a long computation because waveforms need to be extracted for every spike.
@@ -359,7 +359,9 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs):
unit_channels,
pca_model,
)
processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name="extract PCs", **job_kwargs)
processor = ChunkRecordingExecutor(
recording, func, init_func, init_args, job_name="extract PCs", verbose=verbose, **job_kwargs
)
processor.run()

def _fit_by_channel_local(self, n_jobs, progress_bar):
3 changes: 2 additions & 1 deletion src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -107,7 +107,7 @@ def _get_pipeline_nodes(self):
nodes = [spike_retriever_node, spike_amplitudes_node]
return nodes

def _run(self, **job_kwargs):
def _run(self, verbose=False, **job_kwargs):
job_kwargs = fix_job_kwargs(job_kwargs)
nodes = self.get_pipeline_nodes()
amps = run_node_pipeline(
@@ -116,6 +116,7 @@ def _run(self, **job_kwargs):
job_kwargs=job_kwargs,
job_name="spike_amplitudes",
gather_mode="memory",
verbose=False,
)
self.data["amplitudes"] = amps

3 changes: 2 additions & 1 deletion src/spikeinterface/postprocessing/spike_locations.py
@@ -120,7 +120,7 @@ def _get_pipeline_nodes(self):
)
return nodes

def _run(self, **job_kwargs):
def _run(self, verbose=False, **job_kwargs):
job_kwargs = fix_job_kwargs(job_kwargs)
nodes = self.get_pipeline_nodes()
spike_locations = run_node_pipeline(
@@ -129,6 +129,7 @@ def _run(self, **job_kwargs):
job_kwargs=job_kwargs,
job_name="spike_locations",
gather_mode="memory",
verbose=verbose,
)
self.data["spike_locations"] = spike_locations

2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/template_metrics.py
@@ -150,7 +150,7 @@ def _select_extension_data(self, unit_ids):
new_metrics = self.data["metrics"].loc[np.array(unit_ids)]
return dict(metrics=new_metrics)

def _run(self):
def _run(self, verbose=False):
import pandas as pd
from scipy.signal import resample_poly

2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/template_similarity.py
@@ -42,7 +42,7 @@ def _select_extension_data(self, unit_ids):
new_similarity = self.data["similarity"][unit_indices][:, unit_indices]
return dict(similarity=new_similarity)

def _run(self):
def _run(self, verbose=False):
templates_array = get_dense_templates_array(
self.sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled
)
