Commit eb03db8
Merge branch 'main' into verbose_to_non_parallel_case
h-mayorquin authored May 24, 2024
2 parents 33d3592 + d18cec2
Showing 49 changed files with 953 additions and 390 deletions.
7 changes: 4 additions & 3 deletions pyproject.toml
@@ -95,7 +95,7 @@ full = [
"scikit-learn",
"networkx",
"distinctipy",
"matplotlib<3.9", # See https://github.com/SpikeInterface/spikeinterface/issues/2863
"matplotlib>=3.6", # matplotlib.colormaps
"cuda-python; platform_system != 'Darwin'",
"numba",
]
@@ -159,8 +159,8 @@ test = [
]

docs = [
"Sphinx==5.1.1",
"sphinx_rtd_theme==1.0.0",
"Sphinx",
"sphinx_rtd_theme",
"sphinx-gallery",
"sphinx-design",
"numpydoc",
@@ -173,6 +173,7 @@ docs = [
"hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous
"numba", # For many postprocessing functions
"xarray", # For use of SortingAnalyzer zarr format
"networkx",
    # for release we need pypi, so these need to be commented out
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version
6 changes: 3 additions & 3 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -488,9 +488,9 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
self.params["operators"] += [(operator, percentile)]
templates_array = self.data[key]

if save:
if not self.sorting_analyzer.is_read_only():
self.save()
if save:
if not self.sorting_analyzer.is_read_only():
self.save()

if unit_ids is not None:
unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids)
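
For context, a hedged sketch of how get_templates with save=True is typically invoked (assuming a SortingAnalyzer whose "templates" extension has already been computed; the variable names are illustrative):

    # hypothetical analyzer created elsewhere, e.g. with create_sorting_analyzer(...)
    templates_ext = sorting_analyzer.get_extension("templates")

    # requesting an operator that was not computed yet appends it to params["operators"];
    # with save=True and a writable analyzer, the extension is persisted
    median_templates = templates_ext.get_templates(operator="median", save=True)

    # requesting an already-computed operator returns the cached array
    median_templates_again = templates_ext.get_templates(operator="median", save=True)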
10 changes: 7 additions & 3 deletions src/spikeinterface/core/baserecording.py
@@ -491,7 +491,7 @@ def time_to_sample_index(self, time_s, segment_index=None):
rs = self._recording_segments[segment_index]
return rs.time_to_sample_index(time_s)

def _save(self, format="binary", **save_kwargs):
def _save(self, format="binary", verbose: bool = False, **save_kwargs):
# handle t_starts
t_starts = []
has_time_vectors = []
@@ -510,7 +510,7 @@ def _save(self, format="binary", **save_kwargs):
file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())]
dtype = kwargs.get("dtype", None) or self.get_dtype()

write_binary_recording(self, file_paths=file_paths, dtype=dtype, **job_kwargs)
write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs)

from .binaryrecordingextractor import BinaryRecordingExtractor

@@ -540,14 +540,18 @@ def _save(self, format="binary", **save_kwargs):

cached = SharedMemoryRecording.from_recording(self, **job_kwargs)
else:
from spikeinterface.core import NumpyRecording

cached = NumpyRecording.from_recording(self, **job_kwargs)

elif format == "zarr":
from .zarrextractors import ZarrRecordingExtractor

zarr_path = kwargs.pop("zarr_path")
storage_options = kwargs.pop("storage_options")
ZarrRecordingExtractor.write_recording(self, zarr_path, storage_options, **kwargs, **job_kwargs)
ZarrRecordingExtractor.write_recording(
self, zarr_path, storage_options, verbose=verbose, **kwargs, **job_kwargs
)
cached = ZarrRecordingExtractor(zarr_path, storage_options)

elif format == "nwb":
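
Threading verbose through _save means the flag no longer has to travel inside job_kwargs. A minimal usage sketch, assuming a toy recording from generate_recording and an illustrative output folder:

    from spikeinterface.core import generate_recording

    recording = generate_recording(num_channels=4, durations=[2.0])

    # verbose is forwarded by _save to write_binary_recording
    # (or to ZarrRecordingExtractor.write_recording for format="zarr")
    saved = recording.save(format="binary", folder="/tmp/si_saved_recording", verbose=True, n_jobs=1)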
13 changes: 7 additions & 6 deletions src/spikeinterface/core/job_tools.py
@@ -46,7 +46,6 @@
"chunk_duration",
"progress_bar",
"mp_context",
"verbose",
"max_threads_per_process",
)

@@ -131,6 +130,8 @@ def split_job_kwargs(mixed_kwargs):
def divide_segment_into_chunks(num_frames, chunk_size):
if chunk_size is None:
chunks = [(0, num_frames)]
elif chunk_size > num_frames:
chunks = [(0, num_frames)]
else:
n = num_frames // chunk_size

@@ -245,12 +246,12 @@ def ensure_chunk_size(
else:
raise ValueError("chunk_duration must be str or float")
else:
# Edge case to define single chunk per segment for n_jobs=1.
            # All chunking parameters being None means a single chunk per segment
if n_jobs == 1:
# not chunk computing
# TODO Discuss, Sam, is this something that we want to do?
# Even in single process mode, we should chunk the data to avoid loading the whole thing into memory I feel
# Am I wrong?
chunk_size = None
num_segments = recording.get_num_segments()
samples_in_larger_segment = max([recording.get_num_samples(segment) for segment in range(num_segments)])
chunk_size = samples_in_larger_segment
else:
raise ValueError("For n_jobs >1 you must specify total_memory or chunk_size or chunk_memory")

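
Two behaviours follow from these hunks: a chunk can never exceed the segment length, and with n_jobs=1 and no chunking options the largest segment becomes a single chunk. A small sketch of the first helper, limited to the two branches shown above:

    from spikeinterface.core.job_tools import divide_segment_into_chunks

    # chunk_size=None: a single chunk covering the segment
    print(divide_segment_into_chunks(num_frames=1_000, chunk_size=None))   # [(0, 1000)]

    # new branch: chunk_size larger than the segment also falls back to a single chunk
    print(divide_segment_into_chunks(num_frames=1_000, chunk_size=5_000))  # [(0, 1000)]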
8 changes: 4 additions & 4 deletions src/spikeinterface/core/recording_tools.py
@@ -73,6 +73,7 @@ def write_binary_recording(
add_file_extension: bool = True,
byte_offset: int = 0,
auto_cast_uint: bool = True,
verbose: bool = False,
**job_kwargs,
):
"""
@@ -98,6 +99,8 @@
auto_cast_uint: bool, default: True
If True, unsigned integers are automatically cast to int if the specified dtype is signed
.. deprecated:: 0.103, use the `unsigned_to_signed` function instead.
verbose: bool, default: False
Verbosity flag passed to the ChunkRecordingExecutor
{}
"""
job_kwargs = fix_job_kwargs(job_kwargs)
@@ -138,7 +141,7 @@
init_func = _init_binary_worker
init_args = (recording, file_path_dict, dtype, byte_offset, cast_unsigned)
executor = ChunkRecordingExecutor(
recording, func, init_func, init_args, job_name="write_binary_recording", **job_kwargs
recording, func, init_func, init_args, job_name="write_binary_recording", verbose=verbose, **job_kwargs
)
executor.run()

@@ -348,9 +351,6 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=
else:
init_args = (recording, arrays, None, None, dtype, cast_unsigned)

if "verbose" in job_kwargs:
del job_kwargs["verbose"]

executor = ChunkRecordingExecutor(
recording, func, init_func, init_args, verbose=verbose, job_name="write_memory_recording", **job_kwargs
)
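
Because verbose is now an explicit parameter (and is removed from the allowed job kwargs), it should be passed directly rather than inside job_kwargs. A hedged usage sketch with an illustrative output path:

    from spikeinterface.core import generate_recording
    from spikeinterface.core.recording_tools import write_binary_recording

    recording = generate_recording(num_channels=2, durations=[1.0])

    # verbose goes straight to the ChunkRecordingExecutor;
    # job_kwargs keep only chunking / parallelism options
    write_binary_recording(
        recording,
        file_paths=["/tmp/traces_seg0.raw"],
        dtype="float32",
        verbose=True,
        n_jobs=1,
        chunk_duration="500ms",
    )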
31 changes: 22 additions & 9 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -5,6 +5,7 @@
from itertools import chain
import os
import json
import math
import pickle
import weakref
import shutil
@@ -237,7 +238,21 @@ def create(
return_scaled=True,
):
# some checks
assert sorting.sampling_frequency == recording.sampling_frequency
if sorting.sampling_frequency != recording.sampling_frequency:
if math.isclose(sorting.sampling_frequency, recording.sampling_frequency, abs_tol=1e-2, rel_tol=1e-5):
warnings.warn(
"Sorting and Recording have a small difference in sampling frequency. "
"This could be due to rounding of floats. Using the sampling frequency from the Recording."
)
                # we make a copy here to change the sampling frequency
sorting = NumpySorting.from_sorting(sorting, with_metadata=True, copy_spike_vector=True)
sorting._sampling_frequency = recording.sampling_frequency
else:
raise ValueError(
f"Sorting and Recording sampling frequencies are too different: "
f"recording: {recording.sampling_frequency} - sorting: {sorting.sampling_frequency}. "
"Ensure that you are associating the correct Recording and Sorting when creating a SortingAnalyzer."
)
# check that multiple probes are non-overlapping
all_probes = recording.get_probegroup().probes
check_probe_do_not_overlap(all_probes)
@@ -570,9 +585,10 @@ def load_from_zarr(cls, folder, recording=None):
rec_attributes["probegroup"] = None

# sparsity
if "sparsity_mask" in zarr_root.attrs:
# sparsity = zarr_root.attrs["sparsity"]
sparsity = ChannelSparsity(zarr_root["sparsity_mask"], cls.unit_ids, rec_attributes["channel_ids"])
if "sparsity_mask" in zarr_root:
sparsity = ChannelSparsity(
np.array(zarr_root["sparsity_mask"]), sorting.unit_ids, rec_attributes["channel_ids"]
)
else:
sparsity = None

@@ -1581,10 +1597,6 @@ def load_data(self):
self.data[ext_data_name] = ext_data

elif self.format == "zarr":
# Alessio
# TODO: we need decide if we make a copy to memory or keep the lazy loading. For binary_folder it used to be lazy with memmap
# but this make the garbage complicated when a data is hold by a plot but the o SortingAnalyzer is delete
# lets talk
extension_group = self._get_zarr_extension_group(mode="r")
for ext_data_name in extension_group.keys():
ext_data_ = extension_group[ext_data_name]
@@ -1600,7 +1612,8 @@
elif "object" in ext_data_.attrs:
ext_data = ext_data_[0]
else:
ext_data = ext_data_
                    # this loads the array into memory
ext_data = np.array(ext_data_)
self.data[ext_data_name] = ext_data

def copy(self, new_sorting_analyzer, unit_ids=None):
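
The tolerance used above (abs_tol=1e-2, rel_tol=1e-5) accepts float rounding noise, roughly up to 0.3 Hz at a 30 kHz sampling rate, while a genuine mismatch still raises. The check in isolation:

    import math

    recording_fs = 30_000.0
    rounded_fs = 30_000.0000002   # rounding noise: accepted, the Recording value is kept
    wrong_fs = 25_000.0           # genuine mismatch: SortingAnalyzer.create raises ValueError

    assert math.isclose(rounded_fs, recording_fs, abs_tol=1e-2, rel_tol=1e-5)
    assert not math.isclose(wrong_fs, recording_fs, abs_tol=1e-2, rel_tol=1e-5)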
8 changes: 8 additions & 0 deletions src/spikeinterface/core/sparsity.py
@@ -118,6 +118,14 @@ def __repr__(self):
txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - density, P(x=1): {density:0.2f}"
return txt

def __eq__(self, other):
return (
isinstance(other, ChannelSparsity)
and np.array_equal(self.channel_ids, other.channel_ids)
and np.array_equal(self.unit_ids, other.unit_ids)
and np.array_equal(self.mask, other.mask)
)

@property
def unit_id_to_channel_ids(self):
if self._unit_id_to_channel_ids is None:
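
ChannelSparsity.__eq__ compares by value (unit ids, channel ids and mask), which is what the analyzer round-trip test further below relies on. A minimal sketch with hand-built arrays:

    import numpy as np
    from spikeinterface.core.sparsity import ChannelSparsity

    mask = np.array([[True, False, True], [False, True, True]])
    unit_ids = np.array(["u0", "u1"])
    channel_ids = np.array(["ch0", "ch1", "ch2"])

    a = ChannelSparsity(mask, unit_ids, channel_ids)
    b = ChannelSparsity(mask.copy(), unit_ids.copy(), channel_ids.copy())

    assert a == b                             # equal by value, not by identity
    assert (a == "not a sparsity") is False   # non-ChannelSparsity objects compare unequal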
13 changes: 10 additions & 3 deletions src/spikeinterface/core/tests/test_job_tools.py
@@ -42,11 +42,9 @@ def test_ensure_n_jobs():


def test_ensure_chunk_size():
recording = generate_recording(num_channels=2)
    recording = generate_recording(num_channels=2, durations=[5.0, 2.5])  # These are the default durations for two segments
dtype = recording.get_dtype()
assert dtype == "float32"
# make serializable
recording = recording.save()

chunk_size = ensure_chunk_size(recording, total_memory="512M", chunk_size=None, chunk_memory=None, n_jobs=2)
assert chunk_size == 32000000
@@ -69,6 +67,15 @@
chunk_size = ensure_chunk_size(recording, chunk_duration="500ms")
assert chunk_size == 15000

# Test edge case to define single chunk for n_jobs=1
chunk_size = ensure_chunk_size(recording, n_jobs=1, chunk_size=None)
chunks = divide_recording_into_chunks(recording, chunk_size)
assert len(chunks) == recording.get_num_segments()
for chunk in chunks:
segment_index, start_frame, end_frame = chunk
assert start_frame == 0
assert end_frame == recording.get_num_frames(segment_index=segment_index)


def func(segment_index, start_frame, end_frame, worker_ctx):
import os
18 changes: 10 additions & 8 deletions src/spikeinterface/core/tests/test_recording_tools.py
@@ -37,8 +37,8 @@ def test_write_binary_recording(tmp_path):
file_paths = [tmp_path / "binary01.raw"]

# Write binary recording
job_kwargs = dict(verbose=False, n_jobs=1)
write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs)
job_kwargs = dict(n_jobs=1)
write_binary_recording(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs)

# Check if written data matches original data
recorder_binary = BinaryRecordingExtractor(
@@ -64,9 +64,11 @@ def test_write_binary_recording_offset(tmp_path):
file_paths = [tmp_path / "binary01.raw"]

# Write binary recording
job_kwargs = dict(verbose=False, n_jobs=1)
job_kwargs = dict(n_jobs=1)
byte_offset = 125
write_binary_recording(recording, file_paths=file_paths, dtype=dtype, byte_offset=byte_offset, **job_kwargs)
write_binary_recording(
recording, file_paths=file_paths, dtype=dtype, byte_offset=byte_offset, verbose=False, **job_kwargs
)

# Check if written data matches original data
recorder_binary = BinaryRecordingExtractor(
@@ -97,8 +99,8 @@ def test_write_binary_recording_parallel(tmp_path):
file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"]

# Write binary recording
job_kwargs = dict(verbose=False, n_jobs=2, chunk_memory="100k", mp_context="spawn")
write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs)
job_kwargs = dict(n_jobs=2, chunk_memory="100k", mp_context="spawn")
write_binary_recording(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs)

# Check if written data matches original data
recorder_binary = BinaryRecordingExtractor(
@@ -127,8 +129,8 @@ def test_write_binary_recording_multiple_segment(tmp_path):
file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"]

# Write binary recording
job_kwargs = dict(verbose=False, n_jobs=2, chunk_memory="100k", mp_context="spawn")
write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs)
job_kwargs = dict(n_jobs=2, chunk_memory="100k", mp_context="spawn")
write_binary_recording(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs)

# Check if written data matches original data
recorder_binary = BinaryRecordingExtractor(
5 changes: 5 additions & 0 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
@@ -155,10 +155,15 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):

data = sorting_analyzer2.get_extension("dummy").data
assert "result_one" in data
assert isinstance(data["result_one"], str)
assert isinstance(data["result_two"], np.ndarray)
assert data["result_two"].size == original_sorting.to_spike_vector().size
assert np.array_equal(data["result_two"], sorting_analyzer.get_extension("dummy").data["result_two"])

assert sorting_analyzer2.return_scaled == sorting_analyzer.return_scaled

assert sorting_analyzer2.sparsity == sorting_analyzer.sparsity

# select unit_ids to several format
for format in ("memory", "binary_folder", "zarr"):
if format != "memory":
11 changes: 10 additions & 1 deletion src/spikeinterface/preprocessing/highpass_spatial_filter.py
@@ -3,6 +3,7 @@
import numpy as np

from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from .filter import fix_dtype
from ..core import order_channels_by_depth, get_chunk_with_margin
from ..core.core_tools import define_function_from_class

@@ -47,6 +48,8 @@ class HighpassSpatialFilterRecording(BasePreprocessor):
Order of spatial butterworth filter
highpass_butter_wn : float, default: 0.01
Critical frequency (with respect to Nyquist) of spatial butterworth filter
dtype : dtype, default: None
The dtype of the output traces. If None, the dtype is the same as the input traces
Returns
-------
@@ -73,6 +76,7 @@ def __init__(
agc_window_length_s=0.1,
highpass_butter_order=3,
highpass_butter_wn=0.01,
dtype=None,
):
BasePreprocessor.__init__(self, recording)

@@ -117,6 +121,8 @@ def __init__(
butter_kwargs = dict(btype="highpass", N=highpass_butter_order, Wn=highpass_butter_wn)
sos_filter = scipy.signal.butter(**butter_kwargs, output="sos")

dtype = fix_dtype(recording, dtype)

for parent_segment in recording._recording_segments:
rec_segment = HighPassSpatialFilterSegment(
parent_segment,
@@ -128,6 +134,7 @@
sos_filter,
order_f,
order_r,
dtype=dtype,
)
self.add_recording_segment(rec_segment)

@@ -155,6 +162,7 @@ def __init__(
sos_filter,
order_f,
order_r,
dtype,
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.parent_recording_segment = parent_recording_segment
@@ -178,6 +186,7 @@ def __init__(
self.order_r = order_r
# get filter params
self.sos_filter = sos_filter
self.dtype = dtype

def get_traces(self, start_frame, end_frame, channel_indices):
if channel_indices is None:
@@ -234,7 +243,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
traces = traces[left_margin:-right_margin, channel_indices]
else:
traces = traces[left_margin:, channel_indices]
return traces
return traces.astype(self.dtype, copy=False)


# function for API
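
With the new dtype argument the spatially filtered traces can be cast back (for example when scipy's filtering promotes float32 input to float64) instead of changing dtype silently. A hedged usage sketch, assuming a generated recording with a probe attached:

    from spikeinterface.core import generate_recording
    from spikeinterface.preprocessing import highpass_spatial_filter

    recording = generate_recording(num_channels=32, durations=[1.0])

    # dtype=None keeps the parent dtype; an explicit dtype is applied in get_traces
    filtered = highpass_spatial_filter(recording, n_channel_pad=4, dtype="float32")
    traces = filtered.get_traces(start_frame=0, end_frame=1000)
    assert str(traces.dtype) == "float32"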