
Add loading.py and change load_extractor() to load()
alejoe91 committed Jan 15, 2025
1 parent 40d2bdf commit ca276f8
Showing 28 changed files with 225 additions and 207 deletions.
10 changes: 5 additions & 5 deletions src/spikeinterface/benchmark/benchmark_base.py
@@ -11,7 +11,7 @@

from spikeinterface.core import SortingAnalyzer

-from spikeinterface import load_extractor, create_sorting_analyzer, load_sorting_analyzer
+from spikeinterface import load, create_sorting_analyzer, load_sorting_analyzer
from spikeinterface.widgets import get_some_colors


@@ -150,13 +150,13 @@ def scan_folder(self):
analyzer = load_sorting_analyzer(folder)
self.analyzers[key] = analyzer
# the sorting is in memory here we take the saved one because comparisons need to pickle it later
-            sorting = load_extractor(analyzer.folder / "sorting")
+            sorting = load(analyzer.folder / "sorting")
self.datasets[key] = analyzer.recording, sorting

# for rec_file in (self.folder / "datasets" / "recordings").glob("*.pickle"):
# key = rec_file.stem
-        # rec = load_extractor(rec_file)
-        # gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / key)
+        # rec = load(rec_file)
+        # gt_sorting = load(self.folder / f"datasets" / "gt_sortings" / key)
# self.datasets[key] = (rec, gt_sorting)

with open(self.folder / "cases.pickle", "rb") as f:
@@ -428,7 +428,7 @@ def load_folder(cls, folder):
elif format == "sorting":
from spikeinterface.core import load_extractor

-            result[k] = load_extractor(folder / k)
+            result[k] = load(folder / k)
elif format == "Motion":
from spikeinterface.sortingcomponents.motion import Motion

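The benchmark changes are purely mechanical: every argument previously handed to `load_extractor()` is handed to `load()` unchanged. A minimal sketch of the pattern `scan_folder()` uses, assuming an analyzer was previously saved under a hypothetical `study_folder`:

```python
from spikeinterface import load, load_sorting_analyzer

analyzer = load_sorting_analyzer("study_folder/analyzers/case0")  # hypothetical path
# reload the sorting saved next to the analyzer through the new generic entry point
sorting = load(analyzer.folder / "sorting")
```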
4 changes: 2 additions & 2 deletions src/spikeinterface/comparison/multicomparisons.py
@@ -7,7 +7,7 @@

import numpy as np

-from spikeinterface.core import load_extractor, BaseSorting, BaseSortingSegment
+from spikeinterface.core import load, BaseSorting, BaseSortingSegment
from spikeinterface.core.core_tools import define_function_from_class
from .basecomparison import BaseMultiComparison, MixinSpikeTrainComparison, MixinTemplateComparison
from .paircomparisons import SymmetricSortingComparison, TemplateComparison
@@ -230,7 +230,7 @@ def load_from_folder(folder_path):
with (folder_path / "sortings.json").open() as f:
dict_sortings = json.load(f)
name_list = list(dict_sortings.keys())
-        sorting_list = [load_extractor(v, base_folder=folder_path) for v in dict_sortings.values()]
+        sorting_list = [load(v, base_folder=folder_path) for v in dict_sortings.values()]
mcmp = MultiSortingComparison(sorting_list=sorting_list, name_list=list(name_list), do_matching=False, **kwargs)
filename = str(folder_path / "multicomparison.gpickle")
with open(filename, "rb") as f:
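`load_from_folder()` forwards `base_folder=folder_path` so that relative paths stored in `sortings.json` are resolved against the comparison folder. A sketch of that mechanism under stated assumptions (`generate_sorting` from `spikeinterface.core`, `dump_to_json(..., relative_to=...)`, and a hypothetical `mcmp_folder`):

```python
from pathlib import Path
from spikeinterface.core import generate_sorting, load

folder = Path("mcmp_folder")  # hypothetical comparison folder
folder.mkdir(exist_ok=True)
sorting = generate_sorting(num_units=5, durations=[10.0])
saved = sorting.save(folder=folder / "sorting0")
# store paths relative to the folder, then re-anchor them at load time
saved.dump_to_json(folder / "sorting0.json", relative_to=folder)
reloaded = load(folder / "sorting0.json", base_folder=folder)
```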
6 changes: 3 additions & 3 deletions src/spikeinterface/comparison/tests/test_hybrid.py
@@ -1,7 +1,7 @@
import pytest
import shutil
from pathlib import Path
-from spikeinterface.core import extract_waveforms, load_waveforms, load_extractor
+from spikeinterface.core import extract_waveforms, load_waveforms, load
from spikeinterface.core.testing import check_recordings_equal
from spikeinterface.comparison import (
create_hybrid_units_recording,
@@ -52,7 +52,7 @@ def test_hybrid_units_recording(setup_module):
)

# Check dumpability
-    saved_loaded = load_extractor(hybrid_units_recording.to_dict())
+    saved_loaded = load(hybrid_units_recording.to_dict())
check_recordings_equal(hybrid_units_recording, saved_loaded, return_scaled=False)

saved_1job = hybrid_units_recording.save(folder=cache_folder / "units_1job")
@@ -81,7 +81,7 @@ def test_hybrid_spikes_recording(setup_module):
)

# Check dumpability
-    saved_loaded = load_extractor(hybrid_spikes_recording.to_dict())
+    saved_loaded = load(hybrid_spikes_recording.to_dict())
check_recordings_equal(hybrid_spikes_recording, saved_loaded, return_scaled=False)

saved_1job = hybrid_spikes_recording.save(folder=cache_folder / "spikes_1job")
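The dumpability checks above are a dict round trip: `to_dict()` produces a serializable description and `load()` rebuilds the object from it. The same check outside the test suite, in a minimal sketch assuming `generate_recording` from `spikeinterface.core`:

```python
from spikeinterface.core import generate_recording, load
from spikeinterface.core.testing import check_recordings_equal

recording = generate_recording(num_channels=4, durations=[2.0])
loaded = load(recording.to_dict())  # round trip through the dict description
check_recordings_equal(recording, loaded, return_scaled=False)
```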
3 changes: 2 additions & 1 deletion src/spikeinterface/core/__init__.py
@@ -1,10 +1,11 @@
-from .base import load_extractor  # , load_extractor_from_dict, load_extractor_from_json, load_extractor_from_pickle
from .baserecording import BaseRecording, BaseRecordingSegment
from .basesorting import BaseSorting, BaseSortingSegment, SpikeVectorSortingSegment
from .baseevent import BaseEvent, BaseEventSegment
from .basesnippets import BaseSnippets, BaseSnippetsSegment
from .baserecordingsnippets import BaseRecordingSnippets

+from .loading import load, load_extractor

# main extractor from dump and cache
from .binaryrecordingextractor import BinaryRecordingExtractor, read_binary
from .npzsortingextractor import NpzSortingExtractor, read_npz_sorting
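Both names stay importable from `spikeinterface.core`; only `load_extractor` changes behavior, emitting a `DeprecationWarning` before delegating (see `loading.py` below). A short sketch of what calling code observes, assuming `generate_recording`:

```python
import warnings
from spikeinterface.core import generate_recording, load_extractor

recording = generate_recording(num_channels=2, durations=[1.0])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    same = load_extractor(recording.to_dict())  # still works...
assert any(issubclass(w.category, DeprecationWarning) for w in caught)  # ...but warns
```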
122 changes: 8 additions & 114 deletions src/spikeinterface/core/base.py
@@ -673,7 +673,7 @@ def dump_to_json(
) -> None:
"""
Dump recording extractor to json file.
-        The extractor can be re-loaded with load_extractor(json_file)
+        The extractor can be re-loaded with load(json_file)
Parameters
----------
@@ -715,7 +715,7 @@ def dump_to_pickle(
):
"""
Dump recording extractor to a pickle file.
-        The extractor can be re-loaded with load_extractor(pickle_file)
+        The extractor can be re-loaded with load(pickle_file)
Parameters
----------
@@ -752,7 +752,9 @@ def dump_to_pickle(
file_path.write_bytes(pickle.dumps(dump_dict))

@staticmethod
-    def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, bool]] = None) -> "BaseExtractor":
+    def load(
+        file_or_folder_path: Union[str, Path], base_folder: Optional[Union[Path, str, bool]] = None
+    ) -> "BaseExtractor":
"""
Load extractor from file path (.json or .pkl)
@@ -761,74 +763,10 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo
* save (...) a folder which contain data + json (or pickle) + metadata.
"""
-        error_msg = (
-            f"{file_path} is not a file or a folder. It should point to either a json, pickle file or a "
-            "folder that is the result of extractor.save(...)"
-        )
-        if not is_path_remote(file_path):
-            file_path = Path(file_path)
-
-            if base_folder is True:
-                base_folder = file_path.parent
-
-            if file_path.is_file():
-                # standard case based on a file (json or pickle)
-                if str(file_path).endswith(".json"):
-                    with open(file_path, "r") as f:
-                        d = json.load(f)
-                elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"):
-                    with open(file_path, "rb") as f:
-                        d = pickle.load(f)
-                else:
-                    raise ValueError(error_msg)
-
-                if "warning" in d:
-                    print("The extractor was not serializable to file")
-                    return None
-
-                extractor = BaseExtractor.from_dict(d, base_folder=base_folder)
-
-            elif file_path.is_dir():
-                # case from a folder after a calling extractor.save(...)
-                folder = file_path
-                file = None
-
-                if folder.suffix == ".zarr":
-                    from .zarrextractors import read_zarr
-
-                    extractor = read_zarr(folder)
-                else:
-                    # For backward compatibility (v<=0.94) we check for the cached.json/pkl/pickle files
-                    # In later versions (v>0.94) we use the si_folder.json file
-                    for dump_ext in ("json", "pkl", "pickle"):
-                        f = folder / f"cached.{dump_ext}"
-                        if f.is_file():
-                            file = f
-
-                    f = folder / f"si_folder.json"
-                    if f.is_file():
-                        file = f
-
-                    if file is None:
-                        raise ValueError(error_msg)
-                    extractor = BaseExtractor.load(file, base_folder=folder)
-
-            else:
-                raise ValueError(error_msg)
-        else:
-            # remote case - zarr
-            if str(file_path).endswith(".zarr"):
-                from .zarrextractors import read_zarr
-
-                extractor = read_zarr(file_path)
-            else:
-                raise NotImplementedError(
-                    "Only zarr format is supported for remote files and you should provide a path to a .zarr "
-                    "remote path. You can save to a valid zarr folder using: "
-                    "`extractor.save(folder='path/to/folder', format='zarr')`"
-                )
+        # use loading.py and keep backward compatibility
+        from .loading import load

-        return extractor
+        return load(file_or_folder_path, base_folder=base_folder)

def __reduce__(self):
"""
@@ -1179,50 +1117,6 @@ def _check_same_version(class_string, version):
return "unknown"


-def load_extractor(file_or_folder_or_dict, base_folder=None) -> BaseExtractor:
-    """
-    Instantiate extractor from:
-      * a dict
-      * a json file
-      * a pickle file
-      * folder (after save)
-      * a zarr folder (after save)
-
-    Parameters
-    ----------
-    file_or_folder_or_dict : dictionary or folder or file (json, pickle)
-        The file path, folder path, or dictionary to load the extractor from
-    base_folder : str | Path | bool (optional)
-        The base folder to make relative paths absolute.
-        If True and file_or_folder_or_dict is a file, the parent folder of the file is used.
-
-    Returns
-    -------
-    extractor : Recording or Sorting
-        The loaded extractor object
-    """
-    if isinstance(file_or_folder_or_dict, dict):
-        assert not isinstance(base_folder, bool), "`base_folder` must be a string or Path when loading from dict"
-        return BaseExtractor.from_dict(file_or_folder_or_dict, base_folder=base_folder)
-    else:
-        return BaseExtractor.load(file_or_folder_or_dict, base_folder=base_folder)
-
-
-def load_extractor_from_dict(d, base_folder=None) -> BaseExtractor:
-    warnings.warn("Use load_extractor(..) instead")
-    return BaseExtractor.from_dict(d, base_folder=base_folder)
-
-
-def load_extractor_from_json(json_file, base_folder=None) -> "BaseExtractor":
-    warnings.warn("Use load_extractor(..) instead")
-    return BaseExtractor.load(json_file, base_folder=base_folder)
-
-
-def load_extractor_from_pickle(pkl_file, base_folder=None) -> "BaseExtractor":
-    warnings.warn("Use load_extractor(..) instead")
-    return BaseExtractor.load(pkl_file, base_folder=base_folder)


class BaseSegment:
def __init__(self):
self._parent_extractor = None
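With the module-level `load_extractor` and the three `load_extractor_from_*` helpers removed from `base.py`, every load path now funnels through `loading.load()`, and `BaseExtractor.load` survives only as a thin wrapper. A sketch of the replacement mapping, assuming `generate_recording` and a hypothetical `rec.json` output file:

```python
from spikeinterface.core import generate_recording, load
from spikeinterface.core.base import BaseExtractor

recording = generate_recording(num_channels=2, durations=[1.0])

rec_from_dict = load(recording.to_dict())    # replaces load_extractor_from_dict(...)
recording.dump_to_json("rec.json")
rec_from_json = load("rec.json")             # replaces load_extractor_from_json(...)
rec_compat = BaseExtractor.load("rec.json")  # still works; now delegates to load()
```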
127 changes: 127 additions & 0 deletions src/spikeinterface/core/loading.py
@@ -0,0 +1,127 @@
import warnings
from pathlib import Path


from .base import BaseExtractor
from .core_tools import is_path_remote


def load(file_or_folder_or_dict, base_folder=None) -> BaseExtractor:
    """
    General load function to load a SpikeInterface object.

    The function can load:
    - a `Recording` or `Sorting` object from:
        * dictionary
        * json file
        * pkl file
        * binary folder (after `extractor.save(..., format='binary_folder')`)
        * zarr folder (after `extractor.save(..., format='zarr')`)
        * remote zarr folder
    - (TODO) a `SortingAnalyzer` object from:
        * binary folder
        * zarr folder
        * remote zarr folder
        * WaveformExtractor folder

    Parameters
    ----------
    file_or_folder_or_dict : dictionary or folder or file (json, pickle)
        The file path, folder path, or dictionary to load the extractor from
    base_folder : str | Path | bool (optional)
        The base folder to make relative paths absolute.
        If True and file_or_folder_or_dict is a file, the parent folder of the file is used.

    Returns
    -------
    extractor : Recording or Sorting
        The loaded extractor object
    """
    if isinstance(file_or_folder_or_dict, dict):
        assert not isinstance(base_folder, bool), "`base_folder` must be a string or Path when loading from dict"
        return BaseExtractor.from_dict(file_or_folder_or_dict, base_folder=base_folder)
    else:
        file_path = file_or_folder_or_dict
        error_msg = (
            f"{file_path} is not a file or a folder. It should point to either a json, pickle file or a "
            "folder that is the result of extractor.save(...)"
        )
        if not is_path_remote(file_path):
            file_path = Path(file_path)

            if base_folder is True:
                base_folder = file_path.parent

            if file_path.is_file():
                # standard case based on a file (json or pickle)
                if str(file_path).endswith(".json"):
                    import json

                    with open(file_path, "r") as f:
                        d = json.load(f)
                elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"):
                    import pickle

                    with open(file_path, "rb") as f:
                        d = pickle.load(f)
                else:
                    raise ValueError(error_msg)

                # this is for back-compatibility since now unserializable objects will not
                # be saved to file
                if "warning" in d:
                    print("The extractor was not serializable to file")
                    return None

                extractor = BaseExtractor.from_dict(d, base_folder=base_folder)

            elif file_path.is_dir():
                # this can be an extractor, SortingAnalyzer, or WaveformExtractor
                folder = file_path
                file = None

                if folder.suffix == ".zarr":
                    from .zarrextractors import read_zarr

                    extractor = read_zarr(folder)
                else:
                    # For backward compatibility (v<=0.94) we check for the cached.json/pkl/pickle files
                    # In later versions (v>0.94) we use the si_folder.json file
                    for dump_ext in ("json", "pkl", "pickle"):
                        f = folder / f"cached.{dump_ext}"
                        if f.is_file():
                            file = f

                    f = folder / f"si_folder.json"
                    if f.is_file():
                        file = f

                    if file is None:
                        raise ValueError(error_msg)
                    extractor = BaseExtractor.load(file, base_folder=folder)

            else:
                raise ValueError(error_msg)
        else:
            # remote case - zarr
            if str(file_path).endswith(".zarr"):
                from .zarrextractors import read_zarr

                extractor = read_zarr(file_path)
            else:
                raise NotImplementedError(
                    "Only zarr format is supported for remote files and you should provide a path to a .zarr "
                    "remote path. You can save to a valid zarr folder using: "
                    "`extractor.save(folder='path/to/folder', format='zarr')`"
                )

        return extractor


def load_extractor(file_or_folder_or_dict, base_folder=None) -> BaseExtractor:
    warnings.warn(
        "load_extractor() is deprecated and will be removed in the future. Please use load() instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return load(file_or_folder_or_dict, base_folder=base_folder)
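Read as a whole, `load()` dispatches on its input: a dict goes through `from_dict`, a `.json`/`.pkl` file is parsed and rebuilt, a directory is treated as a zarr or saved binary folder, and remote paths are accepted only for zarr. A minimal usage sketch of the local branches, assuming `generate_recording` and hypothetical output paths:

```python
from spikeinterface.core import generate_recording, load

recording = generate_recording(num_channels=4, durations=[2.0])

recording.save(folder="rec_bin")                  # binary folder (si_folder.json inside)
recording.save(folder="rec.zarr", format="zarr")  # zarr folder

rec_from_folder = load("rec_bin")          # folder branch
rec_from_zarr = load("rec.zarr")           # zarr branch (the only branch allowed remotely)
rec_from_dict = load(recording.to_dict())  # dict branch
```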
4 changes: 2 additions & 2 deletions src/spikeinterface/core/recording_tools.py
@@ -247,9 +247,9 @@ def _init_memory_worker(recording, arrays, shm_names, shapes, dtype, cast_unsign
# create a local dict per worker
worker_ctx = {}
if isinstance(recording, dict):
-        from spikeinterface.core import load_extractor
+        from spikeinterface.core import load

worker_ctx["recording"] = load_extractor(recording)
worker_ctx["recording"] = load(recording)
else:
worker_ctx["recording"] = recording

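`_init_memory_worker` accepts either a live recording or its dict description, so the recording can cross process boundaries in serialized form and be rebuilt on the worker side. A sketch of that round trip, again assuming `generate_recording`:

```python
from spikeinterface.core import generate_recording, load

recording = generate_recording(num_channels=2, durations=[1.0])
recording_dict = recording.to_dict()  # picklable description, safe to ship to a worker
# inside the worker, rebuild the extractor exactly as _init_memory_worker does
worker_recording = load(recording_dict)
```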
