From d45dafc6f703477c4220244a48f645b696bf27b4 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Dec 2023 13:05:07 +0100 Subject: [PATCH 1/7] Move some function torecording_tools --- src/spikeinterface/core/__init__.py | 6 +- src/spikeinterface/core/baserecording.py | 3 + .../core/binaryrecordingextractor.py | 3 +- src/spikeinterface/core/core_tools.py | 600 ----------------- src/spikeinterface/core/recording_tools.py | 621 ++++++++++++++++++ .../core/tests/test_core_tools.py | 150 ----- .../core/tests/test_recording_tools.py | 165 ++++- .../extractors/mdaextractors.py | 2 +- .../extractors/shybridextractors.py | 4 +- src/spikeinterface/sorters/external/hdsort.py | 2 +- 10 files changed, 795 insertions(+), 761 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 9f91c8759e..66f6c5313b 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -79,14 +79,14 @@ # tools from .core_tools import ( - write_binary_recording, - write_to_h5_dataset_format, - write_binary_recording, read_python, write_python, ) from .job_tools import ensure_n_jobs, ensure_chunk_size, ChunkRecordingExecutor, split_job_kwargs, fix_job_kwargs from .recording_tools import ( + write_binary_recording, + write_to_h5_dataset_format, + write_binary_recording, get_random_data_chunks, get_channel_distances, get_closest_channels, diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index c112447c35..e33a64f32c 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -12,9 +12,12 @@ check_json, convert_bytes_to_str, convert_seconds_to_str, +) +from .recording_tools import ( write_binary_recording, write_memory_recording, write_traces_to_zarr, + ) from .job_tools import split_job_kwargs diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index d8c6512a38..ce8be46fab 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -6,7 +6,8 @@ import numpy as np from .baserecording import BaseRecording, BaseRecordingSegment -from .core_tools import write_binary_recording, define_function_from_class +from .core_tools import define_function_from_class +from .recording_tools import write_binary_recording from .job_tools import _shared_job_kwargs_doc diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 5892c62b62..44a712dfed 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -5,8 +5,6 @@ import datetime import json from copy import deepcopy -import gc -import mmap import inspect import numpy as np @@ -160,251 +158,8 @@ def add_suffix(file_path, possible_suffix): return file_path -def read_binary_recording(file, num_channels, dtype, time_axis=0, offset=0): - """ - Read binary .bin or .dat file. - - Parameters - ---------- - file: str - File name - num_channels: int - Number of channels - dtype: dtype - dtype of the file - time_axis: 0 or 1, default: 0 - If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file. - If 1, the traces shape (nb_channel, nb_sample) is kept in the file. 
- offset: int, default: 0 - number of offset bytes - - """ - num_channels = int(num_channels) - with Path(file).open() as f: - nsamples = (os.fstat(f.fileno()).st_size - offset) // (num_channels * np.dtype(dtype).itemsize) - if time_axis == 0: - samples = np.memmap(file, np.dtype(dtype), mode="r", offset=offset, shape=(nsamples, num_channels)) - else: - samples = np.memmap(file, np.dtype(dtype), mode="r", offset=offset, shape=(num_channels, nsamples)).T - return samples - - -# used by write_binary_recording + ChunkRecordingExecutor -def _init_binary_worker(recording, file_path_dict, dtype, byte_offest, cast_unsigned): - # create a local dict per worker - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["byte_offset"] = byte_offest - worker_ctx["dtype"] = np.dtype(dtype) - worker_ctx["cast_unsigned"] = cast_unsigned - - file_dict = {segment_index: open(file_path, "r+") for segment_index, file_path in file_path_dict.items()} - worker_ctx["file_dict"] = file_dict - - return worker_ctx - - -def write_binary_recording( - recording, - file_paths, - dtype=None, - add_file_extension=True, - byte_offset=0, - auto_cast_uint=True, - **job_kwargs, -): - """ - Save the trace of a recording extractor in several binary .dat format. - - Note : - time_axis is always 0 (contrary to previous version. - to get time_axis=1 (which is a bad idea) use `write_binary_recording_file_handle()` - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object to be saved in .dat format - file_path: str or list[str] - The path to the file. - dtype: dtype or None, default: None - Type of the saved data - If True, file the ".raw" file extension is added if the file name is not a "raw", "bin", or "dat" - byte_offset: int, default: 0 - Offset in bytes to for the binary file (e.g. 
to write a header) - auto_cast_uint: bool, default: True - If True, unsigned integers are automatically cast to int if the specified dtype is signed - {} - """ - job_kwargs = fix_job_kwargs(job_kwargs) - - file_path_list = [file_paths] if not isinstance(file_paths, list) else file_paths - num_segments = recording.get_num_segments() - if len(file_path_list) != num_segments: - raise ValueError("'file_paths' must be a list of the same size as the number of segments in the recording") - - file_path_list = [Path(file_path) for file_path in file_path_list] - if add_file_extension: - file_path_list = [add_suffix(file_path, ["raw", "bin", "dat"]) for file_path in file_path_list] - - dtype = dtype if dtype is not None else recording.get_dtype() - cast_unsigned = False - if auto_cast_uint: - cast_unsigned = determine_cast_unsigned(recording, dtype) - - dtype_size_bytes = np.dtype(dtype).itemsize - num_channels = recording.get_num_channels() - - file_path_dict = {segment_index: file_path for segment_index, file_path in enumerate(file_path_list)} - for segment_index, file_path in file_path_dict.items(): - num_frames = recording.get_num_frames(segment_index=segment_index) - data_size_bytes = dtype_size_bytes * num_frames * num_channels - file_size_bytes = data_size_bytes + byte_offset - - file = open(file_path, "wb+") - file.truncate(file_size_bytes) - file.close() - assert Path(file_path).is_file() - - # use executor (loop or workers) - func = _write_binary_chunk - init_func = _init_binary_worker - init_args = (recording, file_path_dict, dtype, byte_offset, cast_unsigned) - executor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name="write_binary_recording", **job_kwargs - ) - executor.run() - - -# used by write_binary_recording + ChunkRecordingExecutor -def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): - # recover variables of the worker - recording = worker_ctx["recording"] - dtype = worker_ctx["dtype"] - byte_offset = worker_ctx["byte_offset"] - cast_unsigned = worker_ctx["cast_unsigned"] - file = worker_ctx["file_dict"][segment_index] - - # Open the memmap - # What we need is the file_path - num_channels = recording.get_num_channels() - num_frames = recording.get_num_frames(segment_index=segment_index) - shape = (num_frames, num_channels) - dtype_size_bytes = np.dtype(dtype).itemsize - data_size_bytes = dtype_size_bytes * num_frames * num_channels - - # Offset (The offset needs to be multiple of the page size) - # The mmap offset is associated to be as big as possible but still a multiple of the page size - # The array offset takes care of the reminder - mmap_offset, array_offset = divmod(byte_offset, mmap.ALLOCATIONGRANULARITY) - mmmap_length = data_size_bytes + array_offset - memmap_obj = mmap.mmap(file.fileno(), length=mmmap_length, access=mmap.ACCESS_WRITE, offset=mmap_offset) - - array = np.ndarray.__new__(np.ndarray, shape=shape, dtype=dtype, buffer=memmap_obj, order="C", offset=array_offset) - # apply function - traces = recording.get_traces( - start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, cast_unsigned=cast_unsigned - ) - if traces.dtype != dtype: - traces = traces.astype(dtype, copy=False) - array[start_frame:end_frame, :] = traces - - # Close the memmap - memmap_obj.flush() - - -write_binary_recording.__doc__ = write_binary_recording.__doc__.format(_shared_job_kwargs_doc) - - -def write_binary_recording_file_handle( - recording, file_handle=None, time_axis=0, dtype=None, byte_offset=0, verbose=False, 
**job_kwargs -): - """ - Old variant version of write_binary_recording with one file handle. - Can be useful in some case ??? - Not used anymore at the moment. - - @ SAM useful for writing with time_axis=1! - """ - assert file_handle is not None - assert recording.get_num_segments() == 1, "If file_handle is given then only deals with one segment" - - if dtype is None: - dtype = recording.get_dtype() - - job_kwargs = fix_job_kwargs(job_kwargs) - chunk_size = ensure_chunk_size(recording, **job_kwargs) - - if chunk_size is not None and time_axis == 1: - print("Chunking disabled due to 'time_axis' == 1") - chunk_size = None - - if chunk_size is None: - # no chunking - traces = recording.get_traces(segment_index=0) - if time_axis == 1: - traces = traces.T - if dtype is not None: - traces = traces.astype(dtype, copy=False) - traces.tofile(file_handle) - else: - num_frames = recording.get_num_samples(segment_index=0) - chunks = divide_segment_into_chunks(num_frames, chunk_size) - - for start_frame, end_frame in chunks: - traces = recording.get_traces(segment_index=0, start_frame=start_frame, end_frame=end_frame) - if time_axis == 1: - traces = traces.T - if dtype is not None: - traces = traces.astype(dtype, copy=False) - file_handle.write(traces.tobytes()) - - -# used by write_memory_recording -def _init_memory_worker(recording, arrays, shm_names, shapes, dtype, cast_unsigned): - # create a local dict per worker - worker_ctx = {} - if isinstance(recording, dict): - from spikeinterface.core import load_extractor - - worker_ctx["recording"] = load_extractor(recording) - else: - worker_ctx["recording"] = recording - - worker_ctx["dtype"] = np.dtype(dtype) - - if arrays is None: - # create it from share memory name - from multiprocessing.shared_memory import SharedMemory - - arrays = [] - # keep shm alive - worker_ctx["shms"] = [] - for i in range(len(shm_names)): - shm = SharedMemory(shm_names[i]) - worker_ctx["shms"].append(shm) - arr = np.ndarray(shape=shapes[i], dtype=dtype, buffer=shm.buf) - arrays.append(arr) - worker_ctx["arrays"] = arrays - worker_ctx["cast_unsigned"] = cast_unsigned - return worker_ctx - - -# used by write_memory_recording -def _write_memory_chunk(segment_index, start_frame, end_frame, worker_ctx): - # recover variables of the worker - recording = worker_ctx["recording"] - dtype = worker_ctx["dtype"] - arr = worker_ctx["arrays"][segment_index] - cast_unsigned = worker_ctx["cast_unsigned"] - - # apply function - traces = recording.get_traces( - start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, cast_unsigned=cast_unsigned - ) - traces = traces.astype(dtype, copy=False) - arr[start_frame:end_frame, :] = traces def make_shared_array(shape, dtype): @@ -419,361 +174,6 @@ def make_shared_array(shape, dtype): return arr, shm -def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=True, **job_kwargs): - """ - Save the traces into numpy arrays (memory). 
- try to use the SharedMemory introduce in py3.8 if n_jobs > 1 - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object to be saved in .dat format - dtype: dtype, default: None - Type of the saved data - verbose: bool, default: False - If True, output is verbose (when chunks are used) - auto_cast_uint: bool, default: True - If True, unsigned integers are automatically cast to int if the specified dtype is signed - {} - - Returns - --------- - arrays: one arrays per segment - """ - job_kwargs = fix_job_kwargs(job_kwargs) - - if dtype is None: - dtype = recording.get_dtype() - if auto_cast_uint: - cast_unsigned = determine_cast_unsigned(recording, dtype) - else: - cast_unsigned = False - - # create sharedmmep - arrays = [] - shm_names = [] - shapes = [] - - n_jobs = ensure_n_jobs(recording, n_jobs=job_kwargs.get("n_jobs", 1)) - for segment_index in range(recording.get_num_segments()): - num_frames = recording.get_num_samples(segment_index) - num_channels = recording.get_num_channels() - shape = (num_frames, num_channels) - shapes.append(shape) - if n_jobs > 1: - arr, shm = make_shared_array(shape, dtype) - shm_names.append(shm.name) - else: - arr = np.zeros(shape, dtype=dtype) - arrays.append(arr) - - # use executor (loop or workers) - func = _write_memory_chunk - init_func = _init_memory_worker - if n_jobs > 1: - init_args = (recording, None, shm_names, shapes, dtype, cast_unsigned) - else: - init_args = (recording, arrays, None, None, dtype, cast_unsigned) - - executor = ChunkRecordingExecutor( - recording, func, init_func, init_args, verbose=verbose, job_name="write_memory_recording", **job_kwargs - ) - executor.run() - - return arrays - - -write_memory_recording.__doc__ = write_memory_recording.__doc__.format(_shared_job_kwargs_doc) - - -def write_to_h5_dataset_format( - recording, - dataset_path, - segment_index, - save_path=None, - file_handle=None, - time_axis=0, - single_axis=False, - dtype=None, - chunk_size=None, - chunk_memory="500M", - verbose=False, - auto_cast_uint=True, - return_scaled=False, -): - """ - Save the traces of a recording extractor in an h5 dataset. - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object to be saved in .dat format - dataset_path: str - Path to dataset in h5 file (e.g. "/dataset") - segment_index: int - index of segment - save_path: str, default: None - The path to the file. - file_handle: file handle, default: None - The file handle to dump data. This can be used to append data to an header. In case file_handle is given, - the file is NOT closed after writing the binary data. - time_axis: 0 or 1, default: 0 - If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file. - If 1, the traces shape (nb_channel, nb_sample) is kept in the file. - single_axis: bool, default: False - If True, a single-channel recording is saved as a one dimensional array - dtype: dtype, default: None - Type of the saved data - chunk_size: None or int, default: None - Number of chunks to save the file in. This avoid to much memory consumption for big files. 
- If None and "chunk_memory" is given, the file is saved in chunks of "chunk_memory" MB - chunk_memory: None or str, default: "500M" - Chunk size in bytes must endswith "k", "M" or "G" - verbose: bool, default: False - If True, output is verbose (when chunks are used) - auto_cast_uint: bool, default: True - If True, unsigned integers are automatically cast to int if the specified dtype is signed - return_scaled : bool, default: False - If True and the recording has scaling (gain_to_uV and offset_to_uV properties), - traces are dumped to uV - """ - import h5py - - # ~ assert HAVE_H5, "To write to h5 you need to install h5py: pip install h5py" - assert save_path is not None or file_handle is not None, "Provide 'save_path' or 'file handle'" - - if save_path is not None: - save_path = Path(save_path) - if save_path.suffix == "": - # when suffix is already raw/bin/dat do not change it. - save_path = save_path.parent / (save_path.name + ".h5") - - num_channels = recording.get_num_channels() - num_frames = recording.get_num_frames(segment_index=0) - - if file_handle is not None: - assert isinstance(file_handle, h5py.File) - else: - file_handle = h5py.File(save_path, "w") - - if dtype is None: - dtype_file = recording.get_dtype() - else: - dtype_file = dtype - if auto_cast_uint: - cast_unsigned = determine_cast_unsigned(recording, dtype) - else: - cast_unsigned = False - - if single_axis: - shape = (num_frames,) - else: - if time_axis == 0: - shape = (num_frames, num_channels) - else: - shape = (num_channels, num_frames) - - dset = file_handle.create_dataset(dataset_path, shape=shape, dtype=dtype_file) - - chunk_size = ensure_chunk_size(recording, chunk_size=chunk_size, chunk_memory=chunk_memory, n_jobs=1) - - if chunk_size is None: - traces = recording.get_traces(cast_unsigned=cast_unsigned, return_scaled=return_scaled) - if dtype is not None: - traces = traces.astype(dtype_file, copy=False) - if time_axis == 1: - traces = traces.T - if single_axis: - dset[:] = traces[:, 0] - else: - dset[:] = traces - else: - chunk_start = 0 - # chunk size is not None - n_chunk = num_frames // chunk_size - if num_frames % chunk_size > 0: - n_chunk += 1 - if verbose: - chunks = tqdm(range(n_chunk), ascii=True, desc="Writing to .h5 file") - else: - chunks = range(n_chunk) - for i in chunks: - traces = recording.get_traces( - segment_index=segment_index, - start_frame=i * chunk_size, - end_frame=min((i + 1) * chunk_size, num_frames), - cast_unsigned=cast_unsigned, - return_scaled=return_scaled, - ) - chunk_frames = traces.shape[0] - if dtype is not None: - traces = traces.astype(dtype_file, copy=False) - if single_axis: - dset[chunk_start : chunk_start + chunk_frames] = traces[:, 0] - else: - if time_axis == 0: - dset[chunk_start : chunk_start + chunk_frames, :] = traces - else: - dset[:, chunk_start : chunk_start + chunk_frames] = traces.T - - chunk_start += chunk_frames - - if save_path is not None: - file_handle.close() - return save_path - - -def write_traces_to_zarr( - recording, - zarr_root, - zarr_path, - storage_options, - dataset_paths, - channel_chunk_size=None, - dtype=None, - compressor=None, - filters=None, - verbose=False, - auto_cast_uint=True, - **job_kwargs, -): - """ - Save the trace of a recording extractor in several zarr format. 
- - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object to be saved in .dat format - zarr_root: zarr.Group - The zarr root - zarr_path: str or Path - The path to the zarr file - storage_options: dict or None - Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc. - dataset_paths: list - List of paths to traces datasets in the zarr group - channel_chunk_size: int or None, default: None (chunking in time only) - Channels per chunk - dtype: dtype, default: None - Type of the saved data - compressor: zarr compressor or None, default: None - Zarr compressor - filters: list, default: None - List of zarr filters - verbose: bool, default: False - If True, output is verbose (when chunks are used) - auto_cast_uint: bool, default: True - If True, unsigned integers are automatically cast to int if the specified dtype is signed - {} - """ - assert dataset_paths is not None, "Provide 'file_path'" - - if not isinstance(dataset_paths, list): - dataset_paths = [dataset_paths] - assert len(dataset_paths) == recording.get_num_segments() - - if dtype is None: - dtype = recording.get_dtype() - if auto_cast_uint: - cast_unsigned = determine_cast_unsigned(recording, dtype) - else: - cast_unsigned = False - - job_kwargs = fix_job_kwargs(job_kwargs) - chunk_size = ensure_chunk_size(recording, **job_kwargs) - - # create zarr datasets files - for segment_index in range(recording.get_num_segments()): - num_frames = recording.get_num_samples(segment_index) - num_channels = recording.get_num_channels() - dset_name = dataset_paths[segment_index] - shape = (num_frames, num_channels) - _ = zarr_root.create_dataset( - name=dset_name, - shape=shape, - chunks=(chunk_size, channel_chunk_size), - dtype=dtype, - filters=filters, - compressor=compressor, - ) - # synchronizer=zarr.ThreadSynchronizer()) - - # use executor (loop or workers) - func = _write_zarr_chunk - init_func = _init_zarr_worker - init_args = (recording, zarr_path, storage_options, dataset_paths, dtype, cast_unsigned) - executor = ChunkRecordingExecutor( - recording, func, init_func, init_args, verbose=verbose, job_name="write_zarr_recording", **job_kwargs - ) - executor.run() - - -# used by write_zarr_recording + ChunkRecordingExecutor -def _init_zarr_worker(recording, zarr_path, storage_options, dataset_paths, dtype, cast_unsigned): - import zarr - - # create a local dict per worker - worker_ctx = {} - if isinstance(recording, dict): - from spikeinterface.core import load_extractor - - worker_ctx["recording"] = load_extractor(recording) - else: - worker_ctx["recording"] = recording - - # reload root and datasets - if storage_options is None: - if isinstance(zarr_path, str): - zarr_path_init = zarr_path - zarr_path = Path(zarr_path) - else: - zarr_path_init = str(zarr_path) - else: - zarr_path_init = zarr_path - - root = zarr.open(zarr_path_init, mode="r+", storage_options=storage_options) - zarr_datasets = [] - for dset_name in dataset_paths: - z = root[dset_name] - zarr_datasets.append(z) - worker_ctx["zarr_datasets"] = zarr_datasets - worker_ctx["dtype"] = np.dtype(dtype) - worker_ctx["cast_unsigned"] = cast_unsigned - - return worker_ctx - - -# used by write_zarr_recording + ChunkRecordingExecutor -def _write_zarr_chunk(segment_index, start_frame, end_frame, worker_ctx): - # recover variables of the worker - recording = worker_ctx["recording"] - dtype = worker_ctx["dtype"] - zarr_dataset = worker_ctx["zarr_datasets"][segment_index] - cast_unsigned = 
worker_ctx["cast_unsigned"] - - # apply function - traces = recording.get_traces( - start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, cast_unsigned=cast_unsigned - ) - traces = traces.astype(dtype, copy=False) - zarr_dataset[start_frame:end_frame, :] = traces - - # fix memory leak by forcing garbage collection - del traces - gc.collect() - - -def determine_cast_unsigned(recording, dtype): - recording_dtype = np.dtype(recording.get_dtype()) - - if np.dtype(dtype) != recording_dtype and recording_dtype.kind == "u" and np.dtype(dtype).kind == "i": - cast_unsigned = True - else: - cast_unsigned = False - return cast_unsigned - - def is_dict_extractor(d): """ Check if a dict describe an extractor. diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 6768339eff..e077510696 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -1,9 +1,630 @@ from copy import deepcopy from typing import Literal import warnings +from pathlib import Path +import gc +import mmap + import numpy as np +from .core_tools import add_suffix, make_shared_array +from .job_tools import ( + ensure_chunk_size, + ensure_n_jobs, + divide_segment_into_chunks, + fix_job_kwargs, + ChunkRecordingExecutor, + _shared_job_kwargs_doc, +) + + + +def read_binary_recording(file, num_channels, dtype, time_axis=0, offset=0): + """ + Read binary .bin or .dat file. + + Parameters + ---------- + file: str + File name + num_channels: int + Number of channels + dtype: dtype + dtype of the file + time_axis: 0 or 1, default: 0 + If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file. + If 1, the traces shape (nb_channel, nb_sample) is kept in the file. + offset: int, default: 0 + number of offset bytes + + """ + # TODO change this function to read_binary_traces() because this name is confusing + num_channels = int(num_channels) + with Path(file).open() as f: + nsamples = (os.fstat(f.fileno()).st_size - offset) // (num_channels * np.dtype(dtype).itemsize) + if time_axis == 0: + samples = np.memmap(file, np.dtype(dtype), mode="r", offset=offset, shape=(nsamples, num_channels)) + else: + samples = np.memmap(file, np.dtype(dtype), mode="r", offset=offset, shape=(num_channels, nsamples)).T + return samples + + +# used by write_binary_recording + ChunkRecordingExecutor +def _init_binary_worker(recording, file_path_dict, dtype, byte_offest, cast_unsigned): + # create a local dict per worker + worker_ctx = {} + worker_ctx["recording"] = recording + worker_ctx["byte_offset"] = byte_offest + worker_ctx["dtype"] = np.dtype(dtype) + worker_ctx["cast_unsigned"] = cast_unsigned + + file_dict = {segment_index: open(file_path, "r+") for segment_index, file_path in file_path_dict.items()} + worker_ctx["file_dict"] = file_dict + + return worker_ctx + + +def write_binary_recording( + recording, + file_paths, + dtype=None, + add_file_extension=True, + byte_offset=0, + auto_cast_uint=True, + **job_kwargs, +): + """ + Save the trace of a recording extractor in several binary .dat format. + + Note : + time_axis is always 0 (contrary to previous version. + to get time_axis=1 (which is a bad idea) use `write_binary_recording_file_handle()` + + Parameters + ---------- + recording: RecordingExtractor + The recording extractor object to be saved in .dat format + file_path: str or list[str] + The path to the file. 
+ dtype: dtype or None, default: None + Type of the saved data + If True, file the ".raw" file extension is added if the file name is not a "raw", "bin", or "dat" + byte_offset: int, default: 0 + Offset in bytes to for the binary file (e.g. to write a header) + auto_cast_uint: bool, default: True + If True, unsigned integers are automatically cast to int if the specified dtype is signed + {} + """ + job_kwargs = fix_job_kwargs(job_kwargs) + + file_path_list = [file_paths] if not isinstance(file_paths, list) else file_paths + num_segments = recording.get_num_segments() + if len(file_path_list) != num_segments: + raise ValueError("'file_paths' must be a list of the same size as the number of segments in the recording") + + file_path_list = [Path(file_path) for file_path in file_path_list] + if add_file_extension: + file_path_list = [add_suffix(file_path, ["raw", "bin", "dat"]) for file_path in file_path_list] + + dtype = dtype if dtype is not None else recording.get_dtype() + cast_unsigned = False + if auto_cast_uint: + cast_unsigned = determine_cast_unsigned(recording, dtype) + + dtype_size_bytes = np.dtype(dtype).itemsize + num_channels = recording.get_num_channels() + + file_path_dict = {segment_index: file_path for segment_index, file_path in enumerate(file_path_list)} + for segment_index, file_path in file_path_dict.items(): + num_frames = recording.get_num_frames(segment_index=segment_index) + data_size_bytes = dtype_size_bytes * num_frames * num_channels + file_size_bytes = data_size_bytes + byte_offset + + file = open(file_path, "wb+") + file.truncate(file_size_bytes) + file.close() + assert Path(file_path).is_file() + + # use executor (loop or workers) + func = _write_binary_chunk + init_func = _init_binary_worker + init_args = (recording, file_path_dict, dtype, byte_offset, cast_unsigned) + executor = ChunkRecordingExecutor( + recording, func, init_func, init_args, job_name="write_binary_recording", **job_kwargs + ) + executor.run() + + +# used by write_binary_recording + ChunkRecordingExecutor +def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): + # recover variables of the worker + recording = worker_ctx["recording"] + dtype = worker_ctx["dtype"] + byte_offset = worker_ctx["byte_offset"] + cast_unsigned = worker_ctx["cast_unsigned"] + file = worker_ctx["file_dict"][segment_index] + + # Open the memmap + # What we need is the file_path + num_channels = recording.get_num_channels() + num_frames = recording.get_num_frames(segment_index=segment_index) + shape = (num_frames, num_channels) + dtype_size_bytes = np.dtype(dtype).itemsize + data_size_bytes = dtype_size_bytes * num_frames * num_channels + + # Offset (The offset needs to be multiple of the page size) + # The mmap offset is associated to be as big as possible but still a multiple of the page size + # The array offset takes care of the reminder + mmap_offset, array_offset = divmod(byte_offset, mmap.ALLOCATIONGRANULARITY) + mmmap_length = data_size_bytes + array_offset + memmap_obj = mmap.mmap(file.fileno(), length=mmmap_length, access=mmap.ACCESS_WRITE, offset=mmap_offset) + + array = np.ndarray.__new__(np.ndarray, shape=shape, dtype=dtype, buffer=memmap_obj, order="C", offset=array_offset) + # apply function + traces = recording.get_traces( + start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, cast_unsigned=cast_unsigned + ) + if traces.dtype != dtype: + traces = traces.astype(dtype, copy=False) + array[start_frame:end_frame, :] = traces + + # Close the memmap + 
memmap_obj.flush() + + +write_binary_recording.__doc__ = write_binary_recording.__doc__.format(_shared_job_kwargs_doc) + + +def write_binary_recording_file_handle( + recording, file_handle=None, time_axis=0, dtype=None, byte_offset=0, verbose=False, **job_kwargs +): + """ + Old variant version of write_binary_recording with one file handle. + Can be useful in some case ??? + Not used anymore at the moment. + + @ SAM useful for writing with time_axis=1! + """ + assert file_handle is not None + assert recording.get_num_segments() == 1, "If file_handle is given then only deals with one segment" + + if dtype is None: + dtype = recording.get_dtype() + + job_kwargs = fix_job_kwargs(job_kwargs) + chunk_size = ensure_chunk_size(recording, **job_kwargs) + + if chunk_size is not None and time_axis == 1: + print("Chunking disabled due to 'time_axis' == 1") + chunk_size = None + + if chunk_size is None: + # no chunking + traces = recording.get_traces(segment_index=0) + if time_axis == 1: + traces = traces.T + if dtype is not None: + traces = traces.astype(dtype, copy=False) + traces.tofile(file_handle) + else: + num_frames = recording.get_num_samples(segment_index=0) + chunks = divide_segment_into_chunks(num_frames, chunk_size) + + for start_frame, end_frame in chunks: + traces = recording.get_traces(segment_index=0, start_frame=start_frame, end_frame=end_frame) + if time_axis == 1: + traces = traces.T + if dtype is not None: + traces = traces.astype(dtype, copy=False) + file_handle.write(traces.tobytes()) + + +# used by write_memory_recording +def _init_memory_worker(recording, arrays, shm_names, shapes, dtype, cast_unsigned): + # create a local dict per worker + worker_ctx = {} + if isinstance(recording, dict): + from spikeinterface.core import load_extractor + + worker_ctx["recording"] = load_extractor(recording) + else: + worker_ctx["recording"] = recording + + worker_ctx["dtype"] = np.dtype(dtype) + + if arrays is None: + # create it from share memory name + from multiprocessing.shared_memory import SharedMemory + + arrays = [] + # keep shm alive + worker_ctx["shms"] = [] + for i in range(len(shm_names)): + shm = SharedMemory(shm_names[i]) + worker_ctx["shms"].append(shm) + arr = np.ndarray(shape=shapes[i], dtype=dtype, buffer=shm.buf) + arrays.append(arr) + + worker_ctx["arrays"] = arrays + worker_ctx["cast_unsigned"] = cast_unsigned + + return worker_ctx + + +# used by write_memory_recording +def _write_memory_chunk(segment_index, start_frame, end_frame, worker_ctx): + # recover variables of the worker + recording = worker_ctx["recording"] + dtype = worker_ctx["dtype"] + arr = worker_ctx["arrays"][segment_index] + cast_unsigned = worker_ctx["cast_unsigned"] + + # apply function + traces = recording.get_traces( + start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, cast_unsigned=cast_unsigned + ) + traces = traces.astype(dtype, copy=False) + arr[start_frame:end_frame, :] = traces + + + +def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=True, **job_kwargs): + """ + Save the traces into numpy arrays (memory). 
+ try to use the SharedMemory introduce in py3.8 if n_jobs > 1 + + Parameters + ---------- + recording: RecordingExtractor + The recording extractor object to be saved in .dat format + dtype: dtype, default: None + Type of the saved data + verbose: bool, default: False + If True, output is verbose (when chunks are used) + auto_cast_uint: bool, default: True + If True, unsigned integers are automatically cast to int if the specified dtype is signed + {} + + Returns + --------- + arrays: one arrays per segment + """ + job_kwargs = fix_job_kwargs(job_kwargs) + + if dtype is None: + dtype = recording.get_dtype() + if auto_cast_uint: + cast_unsigned = determine_cast_unsigned(recording, dtype) + else: + cast_unsigned = False + + # create sharedmmep + arrays = [] + shm_names = [] + shapes = [] + + n_jobs = ensure_n_jobs(recording, n_jobs=job_kwargs.get("n_jobs", 1)) + for segment_index in range(recording.get_num_segments()): + num_frames = recording.get_num_samples(segment_index) + num_channels = recording.get_num_channels() + shape = (num_frames, num_channels) + shapes.append(shape) + if n_jobs > 1: + arr, shm = make_shared_array(shape, dtype) + shm_names.append(shm.name) + else: + arr = np.zeros(shape, dtype=dtype) + arrays.append(arr) + + # use executor (loop or workers) + func = _write_memory_chunk + init_func = _init_memory_worker + if n_jobs > 1: + init_args = (recording, None, shm_names, shapes, dtype, cast_unsigned) + else: + init_args = (recording, arrays, None, None, dtype, cast_unsigned) + + executor = ChunkRecordingExecutor( + recording, func, init_func, init_args, verbose=verbose, job_name="write_memory_recording", **job_kwargs + ) + executor.run() + + return arrays + + +write_memory_recording.__doc__ = write_memory_recording.__doc__.format(_shared_job_kwargs_doc) + + + +def write_to_h5_dataset_format( + recording, + dataset_path, + segment_index, + save_path=None, + file_handle=None, + time_axis=0, + single_axis=False, + dtype=None, + chunk_size=None, + chunk_memory="500M", + verbose=False, + auto_cast_uint=True, + return_scaled=False, +): + """ + Save the traces of a recording extractor in an h5 dataset. + + Parameters + ---------- + recording: RecordingExtractor + The recording extractor object to be saved in .dat format + dataset_path: str + Path to dataset in h5 file (e.g. "/dataset") + segment_index: int + index of segment + save_path: str, default: None + The path to the file. + file_handle: file handle, default: None + The file handle to dump data. This can be used to append data to an header. In case file_handle is given, + the file is NOT closed after writing the binary data. + time_axis: 0 or 1, default: 0 + If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file. + If 1, the traces shape (nb_channel, nb_sample) is kept in the file. + single_axis: bool, default: False + If True, a single-channel recording is saved as a one dimensional array + dtype: dtype, default: None + Type of the saved data + chunk_size: None or int, default: None + Number of chunks to save the file in. This avoid to much memory consumption for big files. 
+ If None and "chunk_memory" is given, the file is saved in chunks of "chunk_memory" MB + chunk_memory: None or str, default: "500M" + Chunk size in bytes must endswith "k", "M" or "G" + verbose: bool, default: False + If True, output is verbose (when chunks are used) + auto_cast_uint: bool, default: True + If True, unsigned integers are automatically cast to int if the specified dtype is signed + return_scaled : bool, default: False + If True and the recording has scaling (gain_to_uV and offset_to_uV properties), + traces are dumped to uV + """ + # TODO change the name to write_traces_to_h5_dataset_format() because the name is confusing + import h5py + + # ~ assert HAVE_H5, "To write to h5 you need to install h5py: pip install h5py" + assert save_path is not None or file_handle is not None, "Provide 'save_path' or 'file handle'" + + if save_path is not None: + save_path = Path(save_path) + if save_path.suffix == "": + # when suffix is already raw/bin/dat do not change it. + save_path = save_path.parent / (save_path.name + ".h5") + + num_channels = recording.get_num_channels() + num_frames = recording.get_num_frames(segment_index=0) + + if file_handle is not None: + assert isinstance(file_handle, h5py.File) + else: + file_handle = h5py.File(save_path, "w") + + if dtype is None: + dtype_file = recording.get_dtype() + else: + dtype_file = dtype + if auto_cast_uint: + cast_unsigned = determine_cast_unsigned(recording, dtype) + else: + cast_unsigned = False + + if single_axis: + shape = (num_frames,) + else: + if time_axis == 0: + shape = (num_frames, num_channels) + else: + shape = (num_channels, num_frames) + + dset = file_handle.create_dataset(dataset_path, shape=shape, dtype=dtype_file) + + chunk_size = ensure_chunk_size(recording, chunk_size=chunk_size, chunk_memory=chunk_memory, n_jobs=1) + + if chunk_size is None: + traces = recording.get_traces(cast_unsigned=cast_unsigned, return_scaled=return_scaled) + if dtype is not None: + traces = traces.astype(dtype_file, copy=False) + if time_axis == 1: + traces = traces.T + if single_axis: + dset[:] = traces[:, 0] + else: + dset[:] = traces + else: + chunk_start = 0 + # chunk size is not None + n_chunk = num_frames // chunk_size + if num_frames % chunk_size > 0: + n_chunk += 1 + if verbose: + chunks = tqdm(range(n_chunk), ascii=True, desc="Writing to .h5 file") + else: + chunks = range(n_chunk) + for i in chunks: + traces = recording.get_traces( + segment_index=segment_index, + start_frame=i * chunk_size, + end_frame=min((i + 1) * chunk_size, num_frames), + cast_unsigned=cast_unsigned, + return_scaled=return_scaled, + ) + chunk_frames = traces.shape[0] + if dtype is not None: + traces = traces.astype(dtype_file, copy=False) + if single_axis: + dset[chunk_start : chunk_start + chunk_frames] = traces[:, 0] + else: + if time_axis == 0: + dset[chunk_start : chunk_start + chunk_frames, :] = traces + else: + dset[:, chunk_start : chunk_start + chunk_frames] = traces.T + + chunk_start += chunk_frames + + if save_path is not None: + file_handle.close() + return save_path + + +def write_traces_to_zarr( + recording, + zarr_root, + zarr_path, + storage_options, + dataset_paths, + channel_chunk_size=None, + dtype=None, + compressor=None, + filters=None, + verbose=False, + auto_cast_uint=True, + **job_kwargs, +): + """ + Save the trace of a recording extractor in several zarr format. 
+ + + Parameters + ---------- + recording: RecordingExtractor + The recording extractor object to be saved in .dat format + zarr_root: zarr.Group + The zarr root + zarr_path: str or Path + The path to the zarr file + storage_options: dict or None + Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc. + dataset_paths: list + List of paths to traces datasets in the zarr group + channel_chunk_size: int or None, default: None (chunking in time only) + Channels per chunk + dtype: dtype, default: None + Type of the saved data + compressor: zarr compressor or None, default: None + Zarr compressor + filters: list, default: None + List of zarr filters + verbose: bool, default: False + If True, output is verbose (when chunks are used) + auto_cast_uint: bool, default: True + If True, unsigned integers are automatically cast to int if the specified dtype is signed + {} + """ + assert dataset_paths is not None, "Provide 'file_path'" + + if not isinstance(dataset_paths, list): + dataset_paths = [dataset_paths] + assert len(dataset_paths) == recording.get_num_segments() + + if dtype is None: + dtype = recording.get_dtype() + if auto_cast_uint: + cast_unsigned = determine_cast_unsigned(recording, dtype) + else: + cast_unsigned = False + + job_kwargs = fix_job_kwargs(job_kwargs) + chunk_size = ensure_chunk_size(recording, **job_kwargs) + + # create zarr datasets files + for segment_index in range(recording.get_num_segments()): + num_frames = recording.get_num_samples(segment_index) + num_channels = recording.get_num_channels() + dset_name = dataset_paths[segment_index] + shape = (num_frames, num_channels) + _ = zarr_root.create_dataset( + name=dset_name, + shape=shape, + chunks=(chunk_size, channel_chunk_size), + dtype=dtype, + filters=filters, + compressor=compressor, + ) + # synchronizer=zarr.ThreadSynchronizer()) + + # use executor (loop or workers) + func = _write_zarr_chunk + init_func = _init_zarr_worker + init_args = (recording, zarr_path, storage_options, dataset_paths, dtype, cast_unsigned) + executor = ChunkRecordingExecutor( + recording, func, init_func, init_args, verbose=verbose, job_name="write_zarr_recording", **job_kwargs + ) + executor.run() + + +# used by write_zarr_recording + ChunkRecordingExecutor +def _init_zarr_worker(recording, zarr_path, storage_options, dataset_paths, dtype, cast_unsigned): + import zarr + + # create a local dict per worker + worker_ctx = {} + if isinstance(recording, dict): + from spikeinterface.core import load_extractor + + worker_ctx["recording"] = load_extractor(recording) + else: + worker_ctx["recording"] = recording + + # reload root and datasets + if storage_options is None: + if isinstance(zarr_path, str): + zarr_path_init = zarr_path + zarr_path = Path(zarr_path) + else: + zarr_path_init = str(zarr_path) + else: + zarr_path_init = zarr_path + + root = zarr.open(zarr_path_init, mode="r+", storage_options=storage_options) + zarr_datasets = [] + for dset_name in dataset_paths: + z = root[dset_name] + zarr_datasets.append(z) + worker_ctx["zarr_datasets"] = zarr_datasets + worker_ctx["dtype"] = np.dtype(dtype) + worker_ctx["cast_unsigned"] = cast_unsigned + + return worker_ctx + + +# used by write_zarr_recording + ChunkRecordingExecutor +def _write_zarr_chunk(segment_index, start_frame, end_frame, worker_ctx): + # recover variables of the worker + recording = worker_ctx["recording"] + dtype = worker_ctx["dtype"] + zarr_dataset = worker_ctx["zarr_datasets"][segment_index] + cast_unsigned = 
worker_ctx["cast_unsigned"] + + # apply function + traces = recording.get_traces( + start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, cast_unsigned=cast_unsigned + ) + traces = traces.astype(dtype, copy=False) + zarr_dataset[start_frame:end_frame, :] = traces + + # fix memory leak by forcing garbage collection + del traces + gc.collect() + + +def determine_cast_unsigned(recording, dtype): + recording_dtype = np.dtype(recording.get_dtype()) + + if np.dtype(dtype) != recording_dtype and recording_dtype.kind == "u" and np.dtype(dtype).kind == "i": + cast_unsigned = True + else: + cast_unsigned = False + return cast_unsigned + def get_random_data_chunks( recording, diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 006fae5fb8..42f5c340de 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -6,8 +6,6 @@ import numpy as np from spikeinterface.core.core_tools import ( - write_binary_recording, - write_memory_recording, recursive_path_modifier, make_paths_relative, make_paths_absolute, @@ -24,147 +22,6 @@ cache_folder = Path("cache_folder") / "core" -def test_write_binary_recording(tmp_path): - # Test write_binary_recording() with loop (n_jobs=1) - # Setup - sampling_frequency = 30_000 - num_channels = 2 - dtype = "float32" - - durations = [10.0] - recording = NoiseGeneratorRecording( - durations=durations, - num_channels=num_channels, - sampling_frequency=sampling_frequency, - strategy="tile_pregenerated", - ) - file_paths = [tmp_path / "binary01.raw"] - - # Write binary recording - job_kwargs = dict(verbose=False, n_jobs=1) - write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs) - - # Check if written data matches original data - recorder_binary = BinaryRecordingExtractor( - file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype - ) - assert np.allclose(recorder_binary.get_traces(), recording.get_traces()) - - -def test_write_binary_recording_offset(tmp_path): - # Test write_binary_recording() with loop (n_jobs=1) - # Setup - sampling_frequency = 30_000 - num_channels = 2 - dtype = "float32" - - durations = [10.0] - recording = NoiseGeneratorRecording( - durations=durations, - num_channels=num_channels, - sampling_frequency=sampling_frequency, - strategy="tile_pregenerated", - ) - file_paths = [tmp_path / "binary01.raw"] - - # Write binary recording - job_kwargs = dict(verbose=False, n_jobs=1) - byte_offset = 125 - write_binary_recording(recording, file_paths=file_paths, dtype=dtype, byte_offset=byte_offset, **job_kwargs) - - # Check if written data matches original data - recorder_binary = BinaryRecordingExtractor( - file_paths=file_paths, - sampling_frequency=sampling_frequency, - num_channels=num_channels, - dtype=dtype, - file_offset=byte_offset, - ) - assert np.allclose(recorder_binary.get_traces(), recording.get_traces()) - - -def test_write_binary_recording_parallel(tmp_path): - # Test write_binary_recording() with parallel processing (n_jobs=2) - - # Setup - sampling_frequency = 30_000 - num_channels = 2 - dtype = "float32" - durations = [10.30, 3.5] - recording = NoiseGeneratorRecording( - durations=durations, - num_channels=num_channels, - sampling_frequency=sampling_frequency, - dtype=dtype, - strategy="tile_pregenerated", - ) - file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] - - # Write binary recording - job_kwargs = 
dict(verbose=False, n_jobs=2, chunk_memory="100k", mp_context="spawn") - write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs) - - # Check if written data matches original data - recorder_binary = BinaryRecordingExtractor( - file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype - ) - for segment_index in range(recording.get_num_segments()): - binary_traces = recorder_binary.get_traces(segment_index=segment_index) - recording_traces = recording.get_traces(segment_index=segment_index) - assert np.allclose(binary_traces, recording_traces) - - -def test_write_binary_recording_multiple_segment(tmp_path): - # Test write_binary_recording() with multiple segments (n_jobs=2) - # Setup - sampling_frequency = 30_000 - num_channels = 10 - dtype = "float32" - - durations = [10.30, 3.5] - recording = NoiseGeneratorRecording( - durations=durations, - num_channels=num_channels, - sampling_frequency=sampling_frequency, - strategy="tile_pregenerated", - ) - file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] - - # Write binary recording - job_kwargs = dict(verbose=False, n_jobs=2, chunk_memory="100k", mp_context="spawn") - write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs) - - # Check if written data matches original data - recorder_binary = BinaryRecordingExtractor( - file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype - ) - - for segment_index in range(recording.get_num_segments()): - binary_traces = recorder_binary.get_traces(segment_index=segment_index) - recording_traces = recording.get_traces(segment_index=segment_index) - assert np.allclose(binary_traces, recording_traces) - - -def test_write_memory_recording(): - # 2 segments - recording = NoiseGeneratorRecording( - num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated" - ) - recording = recording.save() - - # write with loop - write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1) - - write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1, chunk_memory="100k", progress_bar=True) - - if platform.system() != "Windows": - # write parrallel - write_memory_recording(recording, dtype=None, verbose=False, n_jobs=2, chunk_memory="100k") - - # write parrallel - write_memory_recording(recording, dtype=None, verbose=False, n_jobs=2, total_memory="200k", progress_bar=True) - - def test_path_utils_functions(): if platform.system() != "Windows": # posix path @@ -231,12 +88,5 @@ def test_path_utils_functions(): if __name__ == "__main__": - # Create a temporary folder using the standard library - # import tempfile - - # with tempfile.TemporaryDirectory() as tmpdirname: - # tmp_path = Path(tmpdirname) - # test_write_binary_recording(tmp_path) - # test_write_memory_recording() test_path_utils_functions() diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 1d99b192ee..8167438ce7 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -1,8 +1,16 @@ +from pathlib import Path +import platform import numpy as np from spikeinterface.core import NumpyRecording, generate_recording +from spikeinterface.core.binaryrecordingextractor import BinaryRecordingExtractor +from spikeinterface.core.generate import NoiseGeneratorRecording + + from spikeinterface.core.recording_tools import ( + 
write_binary_recording, + write_memory_recording, get_random_data_chunks, get_chunk_with_margin, get_closest_channels, @@ -12,6 +20,147 @@ ) +def test_write_binary_recording(tmp_path): + # Test write_binary_recording() with loop (n_jobs=1) + # Setup + sampling_frequency = 30_000 + num_channels = 2 + dtype = "float32" + + durations = [10.0] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw"] + + # Write binary recording + job_kwargs = dict(verbose=False, n_jobs=1) + write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype + ) + assert np.allclose(recorder_binary.get_traces(), recording.get_traces()) + + +def test_write_binary_recording_offset(tmp_path): + # Test write_binary_recording() with loop (n_jobs=1) + # Setup + sampling_frequency = 30_000 + num_channels = 2 + dtype = "float32" + + durations = [10.0] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw"] + + # Write binary recording + job_kwargs = dict(verbose=False, n_jobs=1) + byte_offset = 125 + write_binary_recording(recording, file_paths=file_paths, dtype=dtype, byte_offset=byte_offset, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, + sampling_frequency=sampling_frequency, + num_channels=num_channels, + dtype=dtype, + file_offset=byte_offset, + ) + assert np.allclose(recorder_binary.get_traces(), recording.get_traces()) + + +def test_write_binary_recording_parallel(tmp_path): + # Test write_binary_recording() with parallel processing (n_jobs=2) + + # Setup + sampling_frequency = 30_000 + num_channels = 2 + dtype = "float32" + durations = [10.30, 3.5] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + dtype=dtype, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] + + # Write binary recording + job_kwargs = dict(verbose=False, n_jobs=2, chunk_memory="100k", mp_context="spawn") + write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype + ) + for segment_index in range(recording.get_num_segments()): + binary_traces = recorder_binary.get_traces(segment_index=segment_index) + recording_traces = recording.get_traces(segment_index=segment_index) + assert np.allclose(binary_traces, recording_traces) + + +def test_write_binary_recording_multiple_segment(tmp_path): + # Test write_binary_recording() with multiple segments (n_jobs=2) + # Setup + sampling_frequency = 30_000 + num_channels = 10 + dtype = "float32" + + durations = [10.30, 3.5] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw", tmp_path / 
"binary02.raw"] + + # Write binary recording + job_kwargs = dict(verbose=False, n_jobs=2, chunk_memory="100k", mp_context="spawn") + write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype + ) + + for segment_index in range(recording.get_num_segments()): + binary_traces = recorder_binary.get_traces(segment_index=segment_index) + recording_traces = recording.get_traces(segment_index=segment_index) + assert np.allclose(binary_traces, recording_traces) + + +def test_write_memory_recording(): + # 2 segments + recording = NoiseGeneratorRecording( + num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated" + ) + recording = recording.save() + + # write with loop + write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1) + + write_memory_recording(recording, dtype=None, verbose=True, n_jobs=1, chunk_memory="100k", progress_bar=True) + + if platform.system() != "Windows": + # write parrallel + write_memory_recording(recording, dtype=None, verbose=False, n_jobs=2, chunk_memory="100k") + + # write parrallel + write_memory_recording(recording, dtype=None, verbose=False, n_jobs=2, total_memory="200k", progress_bar=True) + + def test_get_random_data_chunks(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) chunks = get_random_data_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) @@ -148,7 +297,17 @@ def test_order_channels_by_depth(): if __name__ == "__main__": - # test_get_random_data_chunks() - # test_get_closest_channels() - # test_get_noise_levels() + # Create a temporary folder using the standard library + import tempfile + + with tempfile.TemporaryDirectory() as tmpdirname: + tmp_path = Path(tmpdirname) + test_write_binary_recording(tmp_path) + test_write_memory_recording() + + + + test_get_random_data_chunks() + test_get_closest_channels() + test_get_noise_levels() test_order_channels_by_depth() diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index e55f6b4a53..86e3e88e65 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -10,7 +10,7 @@ from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting, BaseSortingSegment from spikeinterface.core.core_tools import define_function_from_class -from spikeinterface.core.core_tools import write_binary_recording +from spikeinterface.core import write_binary_recording from spikeinterface.core.job_tools import fix_job_kwargs diff --git a/src/spikeinterface/extractors/shybridextractors.py b/src/spikeinterface/extractors/shybridextractors.py index ccb97e31b3..66cbd36cb0 100644 --- a/src/spikeinterface/extractors/shybridextractors.py +++ b/src/spikeinterface/extractors/shybridextractors.py @@ -4,8 +4,8 @@ import probeinterface -from spikeinterface.core import BinaryRecordingExtractor, BaseRecordingSegment, BaseSorting, BaseSortingSegment -from spikeinterface.core.core_tools import write_binary_recording, define_function_from_class +from spikeinterface.core import BinaryRecordingExtractor, BaseRecordingSegment, BaseSorting, BaseSortingSegment, write_binary_recording +from spikeinterface.core.core_tools import define_function_from_class class 
SHYBRIDRecordingExtractor(BinaryRecordingExtractor): diff --git a/src/spikeinterface/sorters/external/hdsort.py b/src/spikeinterface/sorters/external/hdsort.py index f3a55bce91..7db2070b0c 100644 --- a/src/spikeinterface/sorters/external/hdsort.py +++ b/src/spikeinterface/sorters/external/hdsort.py @@ -6,7 +6,7 @@ import numpy as np -from spikeinterface.core.core_tools import write_to_h5_dataset_format +from spikeinterface.core import write_to_h5_dataset_format from ..basesorter import BaseSorter from ..utils import ShellScript From 2e91416e301c107227e80a4bbe0d19b9c51ce661 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Dec 2023 12:11:17 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/core_tools.py | 11 ----------- src/spikeinterface/core/recording_tools.py | 3 --- src/spikeinterface/core/tests/test_core_tools.py | 1 - src/spikeinterface/core/tests/test_recording_tools.py | 2 -- src/spikeinterface/extractors/shybridextractors.py | 8 +++++++- 5 files changed, 7 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 1fa044c804..ccf589b6e7 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -8,7 +8,6 @@ from copy import deepcopy - import numpy as np from tqdm import tqdm @@ -160,10 +159,6 @@ def add_suffix(file_path, possible_suffix): return file_path - - - - def make_shared_array(shape, dtype): from multiprocessing.shared_memory import SharedMemory @@ -176,12 +171,6 @@ def make_shared_array(shape, dtype): return arr, shm - - - - - - def is_dict_extractor(d): """ Check if a dict describe an extractor. diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 3824cb6cfe..d3aabf657c 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -21,7 +21,6 @@ ) - def read_binary_recording(file, num_channels, dtype, time_axis=0, offset=0): """ Read binary .bin or .dat file. @@ -270,7 +269,6 @@ def _write_memory_chunk(segment_index, start_frame, end_frame, worker_ctx): arr[start_frame:end_frame, :] = traces - def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=True, **job_kwargs): """ Save the traces into numpy arrays (memory). 
From 7dafa9d7f46626808dd8f0a5dcb5a418e9e44e45 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 20 Dec 2023 13:45:56 +0100
Subject: [PATCH 3/7] oups

---
 src/spikeinterface/core/zarrextractors.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py
index c9540ffe46..32ab5f542a 100644
--- a/src/spikeinterface/core/zarrextractors.py
+++ b/src/spikeinterface/core/zarrextractors.py
@@ -10,6 +10,7 @@
 from .basesorting import BaseSorting, SpikeVectorSortingSegment, minimum_spike_dtype
 from .core_tools import define_function_from_class, check_json
 from .job_tools import split_job_kwargs
+from .recording_tools import determine_cast_unsigned
 class ZarrRecordingExtractor(BaseRecording):
@@ -452,7 +453,7 @@ def add_traces_to_zarr(
         fix_job_kwargs,
         ChunkRecordingExecutor,
     )
-    from .core_tools import determine_cast_unsigned
+
     assert dataset_paths is not None, "Provide 'file_path'"

From adefbdc30a37b0c0a8d9ab7d4b4436ae425d710c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 20 Dec 2023 12:46:19 +0000
Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/core/zarrextractors.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py
index 32ab5f542a..881f6ffede 100644
--- a/src/spikeinterface/core/zarrextractors.py
+++ b/src/spikeinterface/core/zarrextractors.py
@@ -453,7 +453,6 @@ def add_traces_to_zarr(
         fix_job_kwargs,
         ChunkRecordingExecutor,
     )
-
     assert dataset_paths is not None, "Provide 'file_path'"
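
With patches 1-4 in place, the relocated helpers live in recording_tools.py and remain importable from the core package root (the __init__.py hunk in the next patch only drops a duplicated entry from that export list). A minimal sketch of the resulting import paths, reusing the call already shown in the test file; the recording itself is an illustrative stand-in:

    from spikeinterface.core import generate_recording, get_random_data_chunks
    from spikeinterface.core.recording_tools import write_binary_recording, write_to_h5_dataset_format  # new home after the move

    rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0])
    chunks = get_random_data_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0)
    print(chunks.shape)  # chunks pooled across both segments, one column per channel
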
From 3c342d871fb759bab5aa0a891c61f56fb60c7325 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 12 Jan 2024 12:00:27 +0100
Subject: [PATCH 5/7] Update src/spikeinterface/core/__init__.py

---
 src/spikeinterface/core/__init__.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py
index 4b9fedcd6f..1c19052966 100644
--- a/src/spikeinterface/core/__init__.py
+++ b/src/spikeinterface/core/__init__.py
@@ -86,7 +86,6 @@
 from .recording_tools import (
     write_binary_recording,
     write_to_h5_dataset_format,
-    write_binary_recording,
     get_random_data_chunks,
     get_channel_distances,
     get_closest_channels,

From cdd09976e79c1a3f8a5924d22005ebb13b4cee4d Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 12 Jan 2024 12:03:22 +0100
Subject: [PATCH 6/7] Apply suggestions from Zach

Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com>
---
 src/spikeinterface/core/recording_tools.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py
index d3aabf657c..0dc7674bc4 100644
--- a/src/spikeinterface/core/recording_tools.py
+++ b/src/spikeinterface/core/recording_tools.py
@@ -92,7 +92,7 @@ def write_binary_recording(
         Type of the saved data
         If True, file the ".raw" file extension is added if the file name is not a "raw", "bin", or "dat"
     byte_offset: int, default: 0
-        Offset in bytes to for the binary file (e.g. to write a header)
+        Offset in bytes for the binary file (e.g. to write a header)
     auto_cast_uint: bool, default: True
         If True, unsigned integers are automatically cast to int if the specified dtype is signed
     {}
@@ -288,7 +288,7 @@ def write_memory_recording(recording, dtype=None, verbose=False, auto_cast_uint=
     Returns
     ---------
-    arrays: one arrays per segment
+    arrays: one array per segment
     """
     job_kwargs = fix_job_kwargs(job_kwargs)
@@ -375,10 +375,10 @@ def write_to_h5_dataset_format(
     dtype: dtype, default: None
         Type of the saved data
     chunk_size: None or int, default: None
-        Number of chunks to save the file in. This avoid to much memory consumption for big files.
+        Number of chunks to save the file in. This avoids too much memory consumption for big files.
         If None and "chunk_memory" is given, the file is saved in chunks of "chunk_memory" MB
     chunk_memory: None or str, default: "500M"
-        Chunk size in bytes must endswith "k", "M" or "G"
+        Chunk size in bytes must end with "k", "M" or "G"
     verbose: bool, default: False
         If True, output is verbose (when chunks are used)
     auto_cast_uint: bool, default: True

From 2d8f12321459e62a10c256985359efe3fa0b7e34 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 12 Jan 2024 12:32:59 +0100
Subject: [PATCH 7/7] Update src/spikeinterface/core/recording_tools.py

---
 src/spikeinterface/core/recording_tools.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py
index 0dc7674bc4..a8f12c8a1c 100644
--- a/src/spikeinterface/core/recording_tools.py
+++ b/src/spikeinterface/core/recording_tools.py
@@ -359,7 +359,7 @@ def write_to_h5_dataset_format(
     recording: RecordingExtractor
         The recording extractor object to be saved in .dat format
     dataset_path: str
-        Path to dataset in h5 file (e.g. "/dataset")
+        Path to dataset in the h5 file (e.g. "/dataset")
     segment_index: int
         index of segment
     save_path: str, default: None