From e78f44b17d36de9ee67db4eee03d08f473c0f821 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 13 Jan 2025 16:06:59 +0100 Subject: [PATCH] Load extractor can read remote zarr use_times in get_duration --- src/spikeinterface/core/base.py | 102 ++++++++++++----------- src/spikeinterface/core/baserecording.py | 21 +++-- 2 files changed, 70 insertions(+), 53 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 1fa218851b..64eb6f6ca3 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -16,7 +16,7 @@ from .globals import get_global_tmp_folder, is_set_global_tmp_folder from .core_tools import ( - check_json, + is_path_remote, clean_zarr_folder_name, is_dict_extractor, SIJsonEncoder, @@ -761,63 +761,71 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo * save (...) a folder which contain data + json (or pickle) + metadata. """ + if not is_path_remote(file_path): + file_path = Path(file_path) + + if base_folder is True: + base_folder = file_path.parent + + if file_path.is_file(): + # standard case based on a file (json or pickle) + if str(file_path).endswith(".json"): + with open(file_path, "r") as f: + d = json.load(f) + elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): + with open(file_path, "rb") as f: + d = pickle.load(f) + else: + raise ValueError(f"Impossible to load {file_path}") + if "warning" in d: + print("The extractor was not serializable to file") + return None - file_path = Path(file_path) - if base_folder is True: - base_folder = file_path.parent - - if file_path.is_file(): - # standard case based on a file (json or pickle) - if str(file_path).endswith(".json"): - with open(file_path, "r") as f: - d = json.load(f) - elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): - with open(file_path, "rb") as f: - d = pickle.load(f) - else: - raise ValueError(f"Impossible to load {file_path}") - if "warning" in d: - print("The extractor was not serializable to file") - return None + extractor = BaseExtractor.from_dict(d, base_folder=base_folder) + return extractor - extractor = BaseExtractor.from_dict(d, base_folder=base_folder) - return extractor + elif file_path.is_dir(): + # case from a folder after a calling extractor.save(...) + folder = file_path + file = None - elif file_path.is_dir(): - # case from a folder after a calling extractor.save(...) - folder = file_path - file = None + if folder.suffix == ".zarr": + from .zarrextractors import read_zarr - if folder.suffix == ".zarr": - from .zarrextractors import read_zarr - - extractor = read_zarr(folder) - else: - # the is spikeinterface<=0.94.0 - # a folder came with 'cached.json' - for dump_ext in ("json", "pkl", "pickle"): - f = folder / f"cached.{dump_ext}" + extractor = read_zarr(folder) + else: + # the is spikeinterface<=0.94.0 + # a folder came with 'cached.json' + for dump_ext in ("json", "pkl", "pickle"): + f = folder / f"cached.{dump_ext}" + if f.is_file(): + file = f + + # spikeinterface>=0.95.0 + f = folder / f"si_folder.json" if f.is_file(): file = f - # spikeinterface>=0.95.0 - f = folder / f"si_folder.json" - if f.is_file(): - file = f + if file is None: + raise ValueError(f"This folder is not a cached folder {file_path}") + extractor = BaseExtractor.load(file, base_folder=folder) + else: + error_msg = ( + f"{file_path} is not a file or a folder. It should point to either a json, pickle file or a " + "folder that is the result of extractor.save(...)" + ) + raise ValueError(error_msg) + else: + # remote case - zarr + if str(file_path).endswith(".zarr"): + from .zarrextractors import read_zarr - if file is None: - raise ValueError(f"This folder is not a cached folder {file_path}") - extractor = BaseExtractor.load(file, base_folder=folder) + extractor = read_zarr(file_path) + else: + raise NotImplementedError("Only zarr format is supported for remote files") return extractor - else: - error_msg = ( - f"{file_path} is not a file or a folder. It should point to either a json, pickle file or a " - "folder that is the result of extractor.save(...)" - ) - raise ValueError(error_msg) - def __reduce__(self): """ This function is used by pickle to serialize the object. diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 7ca527e255..b42bc3a273 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -59,7 +59,7 @@ def __repr__(self): if num_segments > 1: samples_per_segment = [self.get_num_samples(segment_index) for segment_index in range(num_segments)] memory_per_segment_bytes = (self.get_memory_size(segment_index) for segment_index in range(num_segments)) - durations = [self.get_duration(segment_index) for segment_index in range(num_segments)] + durations = [self.get_duration(segment_index, use_times=False) for segment_index in range(num_segments)] samples_per_segment_formated = [f"{samples:,}" for samples in samples_per_segment] durations_per_segment_formated = [convert_seconds_to_str(d) for d in durations] @@ -95,7 +95,7 @@ def _repr_header(self): dtype = self.get_dtype() total_samples = self.get_total_samples() - total_duration = self.get_total_duration() + total_duration = self.get_total_duration(use_times=False) total_memory_size = self.get_total_memory_size() sf_hz = self.get_sampling_frequency() @@ -216,7 +216,7 @@ def get_total_samples(self) -> int: return sum(samples_per_segment) - def get_duration(self, segment_index=None) -> float: + def get_duration(self, segment_index=None, use_times=True) -> float: """ Returns the duration in seconds. @@ -226,6 +226,9 @@ def get_duration(self, segment_index=None) -> float: The sample index to retrieve the duration for. For multi-segment objects, it is required, default: None With single segment recording returns the duration of the single segment + use_times : bool, default: True + If True, the duration is calculated using the time vector if available. + If False, the duration is calculated using the number of samples and the sampling frequency. Returns ------- @@ -234,7 +237,7 @@ def get_duration(self, segment_index=None) -> float: """ segment_index = self._check_segment_index(segment_index) - if self.has_time_vector(segment_index): + if self.has_time_vector(segment_index) and use_times: times = self.get_times(segment_index) segment_duration = times[-1] - times[0] + (1 / self.get_sampling_frequency()) else: @@ -243,16 +246,22 @@ def get_duration(self, segment_index=None) -> float: return segment_duration - def get_total_duration(self) -> float: + def get_total_duration(self, use_times=True) -> float: """ Returns the total duration in seconds + Parameters + ---------- + use_times : bool, default: True + If True, the duration is calculated using the time vector if available. + If False, the duration is calculated using the number of samples and the sampling frequency. + Returns ------- float The duration in seconds """ - duration = sum([self.get_duration(idx) for idx in range(self.get_num_segments())]) + duration = sum([self.get_duration(idx, use_times) for idx in range(self.get_num_segments())]) return duration def get_memory_size(self, segment_index=None) -> int: