Skip to content

Commit

Permalink
Load extractor can read remote zarr; add `use_times` option to `get_duration`
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Jan 13, 2025
1 parent 3c8e96f commit e78f44b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 53 deletions.
102 changes: 55 additions & 47 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from .globals import get_global_tmp_folder, is_set_global_tmp_folder
from .core_tools import (
check_json,
is_path_remote,
clean_zarr_folder_name,
is_dict_extractor,
SIJsonEncoder,
Expand Down Expand Up @@ -761,63 +761,71 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo
* save (...) a folder which contain data + json (or pickle) + metadata.
"""
if not is_path_remote(file_path):
file_path = Path(file_path)

if base_folder is True:
base_folder = file_path.parent

if file_path.is_file():
# standard case based on a file (json or pickle)
if str(file_path).endswith(".json"):
with open(file_path, "r") as f:
d = json.load(f)
elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"):
with open(file_path, "rb") as f:
d = pickle.load(f)
else:
raise ValueError(f"Impossible to load {file_path}")
if "warning" in d:
print("The extractor was not serializable to file")
return None

file_path = Path(file_path)
if base_folder is True:
base_folder = file_path.parent

if file_path.is_file():
# standard case based on a file (json or pickle)
if str(file_path).endswith(".json"):
with open(file_path, "r") as f:
d = json.load(f)
elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"):
with open(file_path, "rb") as f:
d = pickle.load(f)
else:
raise ValueError(f"Impossible to load {file_path}")
if "warning" in d:
print("The extractor was not serializable to file")
return None
extractor = BaseExtractor.from_dict(d, base_folder=base_folder)
return extractor

extractor = BaseExtractor.from_dict(d, base_folder=base_folder)
return extractor
elif file_path.is_dir():
# case from a folder after a calling extractor.save(...)
folder = file_path
file = None

elif file_path.is_dir():
# case from a folder after a calling extractor.save(...)
folder = file_path
file = None
if folder.suffix == ".zarr":
from .zarrextractors import read_zarr

if folder.suffix == ".zarr":
from .zarrextractors import read_zarr

extractor = read_zarr(folder)
else:
# the is spikeinterface<=0.94.0
# a folder came with 'cached.json'
for dump_ext in ("json", "pkl", "pickle"):
f = folder / f"cached.{dump_ext}"
extractor = read_zarr(folder)
else:
# the is spikeinterface<=0.94.0
# a folder came with 'cached.json'
for dump_ext in ("json", "pkl", "pickle"):
f = folder / f"cached.{dump_ext}"
if f.is_file():
file = f

# spikeinterface>=0.95.0
f = folder / f"si_folder.json"
if f.is_file():
file = f

# spikeinterface>=0.95.0
f = folder / f"si_folder.json"
if f.is_file():
file = f
if file is None:
raise ValueError(f"This folder is not a cached folder {file_path}")
extractor = BaseExtractor.load(file, base_folder=folder)
else:
error_msg = (
f"{file_path} is not a file or a folder. It should point to either a json, pickle file or a "
"folder that is the result of extractor.save(...)"
)
raise ValueError(error_msg)
else:
# remote case - zarr
if str(file_path).endswith(".zarr"):
from .zarrextractors import read_zarr

if file is None:
raise ValueError(f"This folder is not a cached folder {file_path}")
extractor = BaseExtractor.load(file, base_folder=folder)
extractor = read_zarr(file_path)
else:
raise NotImplementedError("Only zarr format is supported for remote files")

return extractor

else:
error_msg = (
f"{file_path} is not a file or a folder. It should point to either a json, pickle file or a "
"folder that is the result of extractor.save(...)"
)
raise ValueError(error_msg)

def __reduce__(self):
"""
This function is used by pickle to serialize the object.
Expand Down
21 changes: 15 additions & 6 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __repr__(self):
if num_segments > 1:
samples_per_segment = [self.get_num_samples(segment_index) for segment_index in range(num_segments)]
memory_per_segment_bytes = (self.get_memory_size(segment_index) for segment_index in range(num_segments))
durations = [self.get_duration(segment_index) for segment_index in range(num_segments)]
durations = [self.get_duration(segment_index, use_times=False) for segment_index in range(num_segments)]

samples_per_segment_formated = [f"{samples:,}" for samples in samples_per_segment]
durations_per_segment_formated = [convert_seconds_to_str(d) for d in durations]
Expand Down Expand Up @@ -95,7 +95,7 @@ def _repr_header(self):
dtype = self.get_dtype()

total_samples = self.get_total_samples()
total_duration = self.get_total_duration()
total_duration = self.get_total_duration(use_times=False)
total_memory_size = self.get_total_memory_size()

sf_hz = self.get_sampling_frequency()
Expand Down Expand Up @@ -216,7 +216,7 @@ def get_total_samples(self) -> int:

return sum(samples_per_segment)

def get_duration(self, segment_index=None, use_times=True) -> float:
    """
    Returns the duration in seconds.

    Parameters
    ----------
    segment_index : int or None, default: None
        The segment index to retrieve the duration for.
        For multi-segment objects, it is required, default: None
        With single segment recording returns the duration of the single segment
    use_times : bool, default: True
        If True, the duration is calculated using the time vector if available.
        If False, the duration is calculated using the number of samples and the sampling frequency.

    Returns
    -------
    float
        The duration in seconds
    """
    segment_index = self._check_segment_index(segment_index)

    if self.has_time_vector(segment_index) and use_times:
        # time-vector duration spans first to last timestamp, plus one sample
        # period so a segment is credited for its final sample
        times = self.get_times(segment_index)
        segment_duration = times[-1] - times[0] + (1 / self.get_sampling_frequency())
    else:
        # NOTE(review): this branch was elided in the diff view; reconstructed
        # from the documented contract (samples / sampling frequency) -- confirm upstream
        segment_duration = self.get_num_samples(segment_index) / self.get_sampling_frequency()

    return segment_duration

def get_total_duration(self, use_times=True) -> float:
    """
    Returns the total duration in seconds

    Parameters
    ----------
    use_times : bool, default: True
        If True, the duration is calculated using the time vector if available.
        If False, the duration is calculated using the number of samples and the sampling frequency.

    Returns
    -------
    float
        The duration in seconds
    """
    # sum the per-segment durations, forwarding the use_times choice to each segment
    duration = sum(
        self.get_duration(segment_index, use_times=use_times)
        for segment_index in range(self.get_num_segments())
    )
    return duration

def get_memory_size(self, segment_index=None) -> int:
Expand Down

0 comments on commit e78f44b

Please sign in to comment.