
Commit

Merge branch 'main' into unpin_sphinx
samuelgarcia authored May 21, 2024
2 parents e947d9c + 9aebd5a commit 8b3be40
Showing 16 changed files with 283 additions and 86 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -95,7 +95,7 @@ full = [
"scikit-learn",
"networkx",
"distinctipy",
"matplotlib",
"matplotlib<3.9", # See https://github.com/SpikeInterface/spikeinterface/issues/2863
"cuda-python; platform_system != 'Darwin'",
"numba",
]
46 changes: 36 additions & 10 deletions src/spikeinterface/core/core_tools.py
@@ -10,16 +10,6 @@
from math import prod

import numpy as np
from tqdm import tqdm

from .job_tools import (
ensure_chunk_size,
ensure_n_jobs,
divide_segment_into_chunks,
fix_job_kwargs,
ChunkRecordingExecutor,
_shared_job_kwargs_doc,
)


def define_function_from_class(source_class, name):
@@ -447,6 +437,42 @@ def convert_bytes_to_str(byte_value: int) -> str:
return f"{byte_value:.2f} {suffixes[i]}"


_exponents = {
"k": 1e3,
"M": 1e6,
"G": 1e9,
"T": 1e12,
"P": 1e15, # Decimal (SI) prefixes
"Ki": 1024**1,
"Mi": 1024**2,
"Gi": 1024**3,
"Ti": 1024**4,
"Pi": 1024**5, # Binary (IEC) prefixes
}


def convert_string_to_bytes(memory_string: str) -> int:
"""
Convert a memory size string to the corresponding number of bytes.
Parameters:
memory_string (str): Memory size string (e.g., "1G", "512Mi", "2T").
Returns:
int: Number of bytes.
"""
if memory_string[-2:] in _exponents:
suffix = memory_string[-2:]
mem_value = memory_string[:-2]
else:
suffix = memory_string[-1]
mem_value = memory_string[:-1]

assert suffix in _exponents, f"Unknown suffix: {suffix}"
bytes_value = int(float(mem_value) * _exponents[suffix])
return bytes_value


def is_editable_mode() -> bool:
"""
Check if spikeinterface is installed in editable mode
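
A quick usage check of the new helper (a hedged sketch, not part of this commit; it only assumes convert_string_to_bytes is importable from spikeinterface.core.core_tools as added above):

from spikeinterface.core.core_tools import convert_string_to_bytes

# Decimal (SI) prefixes
assert convert_string_to_bytes("1k") == 1_000
assert convert_string_to_bytes("500M") == 500 * 10**6

# Binary (IEC) prefixes; the two-character suffix is matched first,
# so it must be capitalized exactly as in _exponents ("Ki", "Mi", "Gi", ...)
assert convert_string_to_bytes("1Ki") == 1_024
assert convert_string_to_bytes("2Gi") == 2 * 1024**3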
13 changes: 8 additions & 5 deletions src/spikeinterface/core/globals.py
@@ -42,9 +42,9 @@ def set_global_tmp_folder(folder):
temp_folder_set = True


def is_set_global_tmp_folder():
def is_set_global_tmp_folder() -> bool:
"""
Check is the global path temporary folder have been manually set.
Check if the global path temporary folder has been manually set.
"""
global temp_folder_set
return temp_folder_set
@@ -88,9 +88,9 @@ def set_global_dataset_folder(folder):
dataset_folder_set = True


def is_set_global_dataset_folder():
def is_set_global_dataset_folder() -> bool:
"""
Check is the global path dataset folder have been manually set.
Check if the global path dataset folder has been manually set.
"""
global dataset_folder_set
return dataset_folder_set
@@ -138,7 +138,10 @@ def reset_global_job_kwargs():
global_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True)


def is_set_global_job_kwargs_set():
def is_set_global_job_kwargs_set() -> bool:
"""
Check if the global job kwargs have been manually set.
"""
global global_job_kwargs_set
return global_job_kwargs_set

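
For context on the boolean getters above, a hedged usage sketch (it assumes set_global_job_kwargs and get_global_job_kwargs live in the same spikeinterface.core.globals module, as the import in job_tools.py below suggests for the getter):

from spikeinterface.core.globals import (
    get_global_job_kwargs,
    set_global_job_kwargs,
    is_set_global_job_kwargs_set,
)

print(is_set_global_job_kwargs_set())   # False until the kwargs are set manually
set_global_job_kwargs(n_jobs=4, chunk_duration="1s", progress_bar=True)
print(is_set_global_job_kwargs_set())   # True after a manual set
print(get_global_job_kwargs())          # dict with n_jobs, chunk_duration, progress_bar, ...

Setting the global job kwargs this way is also what silences the new n_jobs warning added to fix_job_kwargs below.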
40 changes: 21 additions & 19 deletions src/spikeinterface/core/job_tools.py
@@ -8,9 +8,9 @@
import platform
import os
import warnings
from spikeinterface.core.core_tools import convert_string_to_bytes

import sys
import contextlib
from tqdm.auto import tqdm

from concurrent.futures import ProcessPoolExecutor
@@ -23,13 +23,14 @@
- chunk_size: int
Number of samples per chunk
- chunk_memory: str
Memory usage for each job (e.g. "100M", "1G")
Memory usage for each job (e.g. "100M", "1G", "500MiB", "2GiB")
- total_memory: str
Total memory usage (e.g. "500M", "2G")
- chunk_duration : str or float or None
Chunk duration in s if float or with units if str (e.g. "1s", "500ms")
* n_jobs: int
Number of jobs to use. With -1 the number of jobs is the same as number of cores
* n_jobs: int | float
Number of jobs to use. With -1 the number of jobs is the same as number of cores.
Using a float between 0 and 1 will use that fraction of the total cores.
* progress_bar: bool
If True, a progress bar is printed
* mp_context: "fork" | "spawn" | None, default: None
@@ -60,7 +61,7 @@


def fix_job_kwargs(runtime_job_kwargs):
from .globals import get_global_job_kwargs
from .globals import get_global_job_kwargs, is_set_global_job_kwargs_set

job_kwargs = get_global_job_kwargs()

@@ -99,6 +100,15 @@ def fix_job_kwargs(runtime_job_kwargs):

job_kwargs["n_jobs"] = max(n_jobs, 1)

if "n_jobs" not in runtime_job_kwargs and job_kwargs["n_jobs"] == 1 and not is_set_global_job_kwargs_set():
warnings.warn(
"`n_jobs` is not set so parallel processing is disabled! "
"To speed up computations, it is recommended to set n_jobs either "
"globally (with the `spikeinterface.set_global_job_kwargs()` function) or "
"locally (with the `n_jobs` argument). Use `spikeinterface.set_global_job_kwargs?` "
"for more information about job_kwargs."
)

return job_kwargs


@@ -149,16 +159,6 @@ def divide_recording_into_chunks(recording, chunk_size):
return all_chunks


_exponents = {"k": 1e3, "M": 1e6, "G": 1e9}


def _mem_to_int(mem):
suffix = mem[-1]
assert suffix in _exponents
mem = int(float(mem[:-1]) * _exponents[suffix])
return mem


def ensure_n_jobs(recording, n_jobs=1):
if n_jobs == -1:
n_jobs = os.cpu_count()
@@ -206,9 +206,11 @@ def ensure_chunk_size(
chunk_size: int or None
size for one chunk per job
chunk_memory: str or None
must end with "k", "M" or "G"
must end with "k", "M", "G", etc for decimal units and "ki", "Mi", "Gi", etc for
binary units. (e.g. "1k", "500M", "2G", "1ki", "500Mi", "2Gi")
total_memory: str or None
must end with "k", "M" or "G"
must end with "k", "M", "G", etc for decimal units and "ki", "Mi", "Gi", etc for
binary units. (e.g. "1k", "500M", "2G", "1ki", "500Mi", "2Gi")
chunk_duration: None or float or str
Units are second if float.
If str then the str must contain units (e.g. "1s", "500ms")
@@ -219,14 +221,14 @@
elif chunk_memory is not None:
assert total_memory is None
# set by memory per worker size
chunk_memory = _mem_to_int(chunk_memory)
chunk_memory = convert_string_to_bytes(chunk_memory)
n_bytes = np.dtype(recording.get_dtype()).itemsize
num_channels = recording.get_num_channels()
chunk_size = int(chunk_memory / (num_channels * n_bytes))
elif total_memory is not None:
# clip by total memory size
n_jobs = ensure_n_jobs(recording, n_jobs=n_jobs)
total_memory = _mem_to_int(total_memory)
total_memory = convert_string_to_bytes(total_memory)
n_bytes = np.dtype(recording.get_dtype()).itemsize
num_channels = recording.get_num_channels()
chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs))
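
A hedged sketch of how the new pieces fit together (not part of this commit): fix_job_kwargs resolves a fractional n_jobs, and ensure_chunk_size now parses both decimal and binary memory strings through convert_string_to_bytes.

from spikeinterface.core.job_tools import fix_job_kwargs

# n_jobs given as a float in (0, 1) is resolved to that fraction of the cores (at least 1);
# passing n_jobs explicitly also avoids the new "parallel processing is disabled" warning.
job_kwargs = fix_job_kwargs({"n_jobs": 0.5, "chunk_memory": "500Mi"})
print(job_kwargs["n_jobs"])        # e.g. 4 on an 8-core machine
print(job_kwargs["chunk_memory"])  # "500Mi", converted to bytes later by ensure_chunk_size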
37 changes: 22 additions & 15 deletions src/spikeinterface/core/recording_tools.py
@@ -152,33 +152,40 @@ def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx):
cast_unsigned = worker_ctx["cast_unsigned"]
file = worker_ctx["file_dict"][segment_index]

# Open the memmap
# What we need is the file_path
num_channels = recording.get_num_channels()
num_frames = recording.get_num_frames(segment_index=segment_index)
shape = (num_frames, num_channels)
dtype_size_bytes = np.dtype(dtype).itemsize
data_size_bytes = dtype_size_bytes * num_frames * num_channels

# Offset (the offset needs to be a multiple of the page size)
# The mmap offset is chosen to be as big as possible while still a multiple of the page size
# The array offset takes care of the remainder
mmap_offset, array_offset = divmod(byte_offset, mmap.ALLOCATIONGRANULARITY)
mmmap_length = data_size_bytes + array_offset
memmap_obj = mmap.mmap(file.fileno(), length=mmmap_length, access=mmap.ACCESS_WRITE, offset=mmap_offset)
# Calculate byte offsets for the start and end frames relative to the entire recording
start_byte = byte_offset + start_frame * num_channels * dtype_size_bytes
end_byte = byte_offset + end_frame * num_channels * dtype_size_bytes

array = np.ndarray.__new__(np.ndarray, shape=shape, dtype=dtype, buffer=memmap_obj, order="C", offset=array_offset)
# apply function
# The mmap offset must be a multiple of mmap.ALLOCATIONGRANULARITY
memmap_offset, start_offset = divmod(start_byte, mmap.ALLOCATIONGRANULARITY)
memmap_offset *= mmap.ALLOCATIONGRANULARITY

# This maps in bytes the region of the memmap that corresponds to the chunk
length = (end_byte - start_byte) + start_offset
memmap_obj = mmap.mmap(file.fileno(), length=length, access=mmap.ACCESS_WRITE, offset=memmap_offset)

# To use numpy semantics we use the array interface of the memmap object
num_frames = end_frame - start_frame
shape = (num_frames, num_channels)
memmap_array = np.ndarray(shape=shape, dtype=dtype, buffer=memmap_obj, offset=start_offset)

# Extract the traces and store them in the memmap array
traces = recording.get_traces(
start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, cast_unsigned=cast_unsigned
)

if traces.dtype != dtype:
traces = traces.astype(dtype, copy=False)
array[start_frame:end_frame, :] = traces

# Close the memmap
memmap_array[...] = traces

memmap_obj.flush()

memmap_obj.close()


write_binary_recording.__doc__ = write_binary_recording.__doc__.format(_shared_job_kwargs_doc)

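
The core of the change above is aligning the mmap offset down to a multiple of mmap.ALLOCATIONGRANULARITY and pushing the remainder into the numpy array offset, so each worker only maps its own chunk. A minimal, self-contained sketch of that pattern (hypothetical file and chunk sizes, not SpikeInterface code):

import mmap
import tempfile
import numpy as np

dtype = np.dtype("int16")
num_channels = 4
byte_offset = 0                      # e.g. the size of a binary header
start_frame, end_frame = 1000, 2000  # hypothetical chunk bounds

with tempfile.TemporaryFile() as f:
    # pre-size the file for 5000 frames
    f.truncate(byte_offset + 5000 * num_channels * dtype.itemsize)

    start_byte = byte_offset + start_frame * num_channels * dtype.itemsize
    end_byte = byte_offset + end_frame * num_channels * dtype.itemsize

    # mmap offsets must be multiples of ALLOCATIONGRANULARITY;
    # the remainder goes into the numpy array offset instead
    memmap_offset, start_offset = divmod(start_byte, mmap.ALLOCATIONGRANULARITY)
    memmap_offset *= mmap.ALLOCATIONGRANULARITY

    length = (end_byte - start_byte) + start_offset
    memmap_obj = mmap.mmap(f.fileno(), length=length, access=mmap.ACCESS_WRITE, offset=memmap_offset)

    chunk = np.ndarray(shape=(end_frame - start_frame, num_channels), dtype=dtype,
                       buffer=memmap_obj, offset=start_offset)
    chunk[:] = 1          # write through the mapped view
    memmap_obj.flush()
    del chunk             # release the exported buffer before closing the mmap
    memmap_obj.close()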
66 changes: 62 additions & 4 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -2,6 +2,7 @@
from typing import Literal, Optional

from pathlib import Path
from itertools import chain
import os
import json
import pickle
@@ -921,11 +922,11 @@ def compute_one_extension(self, extension_name, save=True, **kwargs):
>>> wfs = compute_waveforms(sorting_analyzer, **some_params)
"""
extension_class = get_extension_class(extension_name)

for child in _get_children_dependencies(extension_name):
self.delete_extension(child)

extension_class = get_extension_class(extension_name)

if extension_class.need_job_kwargs:
params, job_kwargs = split_job_kwargs(kwargs)
else:
@@ -978,14 +979,17 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
>>> sorting_analyzer.compute_several_extensions({"waveforms": {"ms_before": 1.2}, "templates" : {"operators": ["average", "std"]}})
"""
for extension_name in extensions.keys():

sorted_extensions = _sort_extensions_by_dependency(extensions)

for extension_name in sorted_extensions.keys():
for child in _get_children_dependencies(extension_name):
self.delete_extension(child)

extensions_with_pipeline = {}
extensions_without_pipeline = {}
extensions_post_pipeline = {}
for extension_name, extension_params in extensions.items():
for extension_name, extension_params in sorted_extensions.items():
if extension_name == "quality_metrics":
# PATCH: the quality metric is computed after the pipeline, since some of the metrics optionally require
# the output of the pipeline extensions (e.g., spike_amplitudes, spike_locations).
@@ -1009,6 +1013,7 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
all_nodes = []
result_routage = []
extension_instances = {}

for extension_name, extension_params in extensions_with_pipeline.items():
extension_class = get_extension_class(extension_name)
assert self.has_recording(), f"Extension {extension_name} need the recording"
@@ -1024,6 +1029,7 @@ def compute_several_extensions(self, extensions, save=True, **job_kwargs):
all_nodes.extend(nodes)

job_name = "Compute : " + " + ".join(extensions_with_pipeline.keys())

results = run_node_pipeline(
self.recording,
all_nodes,
@@ -1191,6 +1197,58 @@ def get_default_extension_params(self, extension_name: str):
return get_default_analyzer_extension_params(extension_name)


def _sort_extensions_by_dependency(extensions):
"""
Sorts a dictionary of extensions so that the parents of each extension are on the "left" of their children.
Assumes there is a valid ordering of the included extensions.
Parameters
----------
extensions: dict
A dict of extensions.
Returns
-------
sorted_extensions: dict
A dict of extensions, with the parents on the left of their children.
"""

extensions_list = list(extensions.keys())
extension_params = list(extensions.values())

i = 0
while i < len(extensions_list):

extension = extensions_list[i]
dependencies = get_extension_class(extension).depend_on

# Split cases with an "or" in them, and flatten into a list
dependencies = list(chain.from_iterable([dependency.split("|") for dependency in dependencies]))

# Only advance if nothing was moved;
# otherwise, re-check the dependency that was just inserted at position i
did_nothing = True
for dependency in dependencies:

# if dependency is on the right, move it left of the current dependency
if dependency in extensions_list[i:]:

dependency_arg = extensions_list.index(dependency)

extension_params.pop(dependency_arg)
extension_params.insert(i, extensions[dependency])

extensions_list.pop(dependency_arg)
extensions_list.insert(i, dependency)

did_nothing = False

if did_nothing:
i += 1

return dict(zip(extensions_list, extension_params))


global _possible_extensions
_possible_extensions = []

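
To illustrate what _sort_extensions_by_dependency does, here is a hedged, standalone re-implementation of the same left-to-right reordering on a hypothetical dependency table (the real function reads depend_on from each registered extension class; the extension names and dependencies below are only for illustration):

from itertools import chain

# Hypothetical dependency table, "a|b" meaning "a or b"
_depend_on = {
    "random_spikes": [],
    "waveforms": ["random_spikes"],
    "templates": ["random_spikes|waveforms"],
}

def sort_by_dependency(extensions: dict) -> dict:
    names = list(extensions)
    params = list(extensions.values())
    i = 0
    while i < len(names):
        # flatten "a|b" alternatives into a plain list of names
        deps = list(chain.from_iterable(d.split("|") for d in _depend_on[names[i]]))
        moved = False
        for dep in deps:
            if dep in names[i:]:
                # move the parent just left of its child, then re-check it at position i
                j = names.index(dep)
                names.insert(i, names.pop(j))
                params.insert(i, params.pop(j))
                moved = True
        if not moved:
            i += 1
    return dict(zip(names, params))

print(list(sort_by_dependency({"templates": {}, "waveforms": {}, "random_spikes": {}})))
# -> ['random_spikes', 'waveforms', 'templates']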
4 changes: 2 additions & 2 deletions src/spikeinterface/core/sparsity.py
@@ -331,7 +331,7 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p
return_scaled = templates_or_sorting_analyzer.return_scaled
elif isinstance(templates_or_sorting_analyzer, Templates):
assert noise_levels is not None
return_scaled = True
return_scaled = templates_or_sorting_analyzer.is_scaled

mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool")

@@ -369,7 +369,7 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None):
return_scaled = templates_or_sorting_analyzer.return_scaled
elif isinstance(templates_or_sorting_analyzer, Templates):
assert noise_levels is not None
return_scaled = True
return_scaled = templates_or_sorting_analyzer.is_scaled

from .template_tools import get_dense_templates_array

(9 more changed files not shown)