Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 30, 2025
1 parent ccdff24 commit 7aa439f
Showing 1 changed file with 31 additions and 21 deletions.
52 changes: 31 additions & 21 deletions src/spikeinterface/core/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@
from .core_tools import is_path_remote




_error_msg = (
"{file_path} is not a file or a folder. It should point to either a json, pickle file or a "
"folder that is the result of extractor.save(...) or sortinganalyzer.save_as(...)"
)


def load(
file_or_folder_or_dict, base_folder=None, **kwargs
file_or_folder_or_dict,
base_folder=None,
**kwargs,
# load_extensions=True, backend_options=None
) -> "BaseExtractor | SortingAnalyzer | Motion | Template":
"""
Expand Down Expand Up @@ -76,38 +77,40 @@ def load(
is_local = not is_path_remote(file_or_folder_or_dict)
if is_local:
file_path = Path(file_or_folder_or_dict)

if is_local and file_path.is_file():
# Standard case based on a file (json or pickle) after a Base.dump(json/pickle)
if base_folder is True:
base_folder = file_path.parent

if str(file_path).endswith(".json"):
import json

with open(file_path, "r") as f:
d = json.load(f)
elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"):
import pickle

with open(file_path, "rb") as f:
d = pickle.load(f)
else:
raise ValueError(_error_msg.format(file_path=file_path))

object_type = _guess_object_from_dict(d)
if object_type is None:
raise ValueError(_error_msg.format(file_path=file_path))
return _load_object_from_dict(d, object_type, base_folder=base_folder)

elif is_local and file_path.is_dir():

folder = file_path
if folder.suffix == ".zarr":
# Local zarr can be
# Sorting Recording SortingAnalyzer Template
object_type = _guess_object_from_zarr(folder)
if object_type is None:
raise ValueError(_error_msg.format(file_path=file_path))
loaded_object = _load_object_from_zarr(folder, object_type, **kwargs)
loaded_object = _load_object_from_zarr(folder, object_type, **kwargs)
return loaded_object

else:
Expand All @@ -116,7 +119,7 @@ def load(
object_type = _guess_object_from_local_folder(folder)
if object_type is None:
raise ValueError(_error_msg.format(file_path=file_path))
loaded_object = _load_object_from_folder(folder, object_type, **kwargs)
loaded_object = _load_object_from_folder(folder, object_type, **kwargs)
return loaded_object
else:
# remote zarr can be
Expand All @@ -125,7 +128,6 @@ def load(
object_type = _guess_object_from_zarr(url)
loaded_object = _load_object_from_zarr(url, object_type, **kwargs)
return loaded_object



def load_extractor(file_or_folder_or_dict, base_folder=None) -> "BaseExtractor":
Expand All @@ -147,7 +149,7 @@ def _guess_object_from_dict(d):
# the case is explicit.
# SortingAnalyzer and Motion used to implement this from the start of implementing
return d["object"]

# Template
is_template = True
for k in ("templates_array", "sparsity_mask", "channel_ids", "unit_ids"):
Expand All @@ -165,24 +167,24 @@ def _guess_object_from_dict(d):
break
if is_sorting_or_recording:
return "Recording|Sorting"

# Unknow
return None


def _load_object_from_dict(d, object_type, base_folder=None):
if object_type in ("Recording", "Sorting", "Recording|Sorting"):
return BaseExtractor.from_dict(d, base_folder=base_folder)

elif object_type == "Templates":
from spikeinterface.core import Templates

return Templates.from_dict(d)

# elif object_type == "Motion":
# TODO to be implemented in Motion.from_dict




def _guess_object_from_local_folder(folder):
folder = Path(folder)

Expand Down Expand Up @@ -213,16 +215,19 @@ def _guess_object_from_local_folder(folder):
def _load_object_from_folder(folder, object_type, **kwargs):
if object_type == "SortingAnalyzer":
from .sortinganalyzer import load_sorting_analyzer

analyzer = load_sorting_analyzer(folder, **kwargs)
return analyzer

elif object_type == "Motion":
from spikeinterface.sortingcomponents.motion import Motion

motion = Motion.load(folder)
return motion

elif object_type == "Waveforms":
from .waveforms_extractor_backwards_compatibility import load_waveforms

analyzer = load_waveforms(folder, output="SortingAnalyzer")
return analyzer

Expand All @@ -243,31 +248,33 @@ def _load_object_from_folder(folder, object_type, **kwargs):
def _guess_object_from_zarr(zarr_folder):
# here it can be a zarr folder for Recording|Sorting|SortingAnalyzer|Template
from .zarrextractors import super_zarr_open

zarr_root = super_zarr_open(zarr_folder, mode="r")

# can be SortingAnalyzer
spikeinterface_info = zarr_root.attrs.get("spikeinterface_info")
if spikeinterface_info is not None:
return _guess_object_from_dict(spikeinterface_info)

# here it is the old fashion and a bit ambiguous
if "channel_ids" in zarr_root.keys() and "unit_ids" in zarr_root.keys() and "nbefore" in zarr_root.keys() :
if "channel_ids" in zarr_root.keys() and "unit_ids" in zarr_root.keys() and "nbefore" in zarr_root.keys():
return "Templates"
elif "channel_ids" in zarr_root.keys() and "unit_ids" not in zarr_root.keys():
return "Recording"
elif "unit_ids" in zarr_root.keys() and "channel_ids" not in zarr_root.keys():
return "Sorting"


def _load_object_from_zarr(folder_or_url, object_type, **kwargs):
def _load_object_from_zarr(folder_or_url, object_type, **kwargs):

storage_options = kwargs.get("storage_options", None)

if object_type == "SortingAnalyzer":
from .sortinganalyzer import load_sorting_analyzer

analyzer = load_sorting_analyzer(folder_or_url, **kwargs)
return analyzer

# elif object_type == "Motion":
# No Motion in zarr

Expand All @@ -276,15 +283,18 @@ def _load_object_from_zarr(folder_or_url, object_type, **kwargs):

elif object_type == "Recording":
from .zarrextractors import read_zarr_recording

recording = read_zarr_recording(folder_or_url, storage_options=storage_options)
return recording
elif object_type == "Sorting":
from .zarrextractors import read_zarr_sorting

sorting = read_zarr_sorting(folder_or_url, storage_options=storage_options)
return sorting
elif object_type == "Recording|Sorting":
# This case shoudl deprecated soon because the read_zarr is ultra ambiguous
# just testing if the zarrot contains unit_ids or channel_ids but many object also contains it (see template)!!!!
from .zarrextractors import read_zarr

rec_or_sorting = read_zarr(folder_or_url, storage_options=storage_options)
return rec_or_sorting

0 comments on commit 7aa439f

Please sign in to comment.