Skip to content

Commit

Permalink
Merge pull request #2271 from samuelgarcia/fix_generator_in_docker
Browse files Browse the repository at this point in the history
run_sorter in container check json or pickle
  • Loading branch information
samuelgarcia authored Nov 30, 2023
2 parents 1000aae + 2f7ca19 commit 78761bc
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions src/spikeinterface/sorters/runsorter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from pathlib import Path
import json
import pickle
import platform
from warnings import warn
from typing import Optional, Union
Expand Down Expand Up @@ -414,9 +415,15 @@ def run_sorter_container(

# create 3 files for communication with container
# recording dict inside
(parent_folder / "in_container_recording.json").write_text(
json.dumps(check_json(rec_dict), indent=4), encoding="utf8"
)
if recording.check_serializability("json"):
(parent_folder / "in_container_recording.json").write_text(
json.dumps(check_json(rec_dict), indent=4), encoding="utf8"
)
elif recording.check_serializability("pickle"):
(parent_folder / "in_container_recording.pickle").write_bytes(pickle.dumps(rec_dict))
else:
raise RuntimeError("To use run_sorter with container the recording must be serializable")

# need to share specific parameters
(parent_folder / "in_container_params.json").write_text(
json.dumps(check_json(sorter_params), indent=4), encoding="utf8"
Expand All @@ -433,13 +440,19 @@ def run_sorter_container(
# the py script
py_script = f"""
import json
from pathlib import Path
from spikeinterface import load_extractor
from spikeinterface.sorters import run_sorter_local
if __name__ == '__main__':
# this __name__ protection help in some case with multiprocessing (for instance HS2)
# load recording in container
recording = load_extractor('{parent_folder_unix}/in_container_recording.json')
json_rec = Path('{parent_folder_unix}/in_container_recording.json')
pickle_rec = Path('{parent_folder_unix}/in_container_recording.pickle')
if json_rec.exists():
recording = load_extractor(json_rec)
else:
recording = load_extractor(pickle_rec)
# load params in container
with open('{parent_folder_unix}/in_container_params.json', encoding='utf8', mode='r') as f:
Expand Down Expand Up @@ -593,7 +606,10 @@ def run_sorter_container(

# clean useless files
if delete_container_files:
os.remove(parent_folder / "in_container_recording.json")
if (parent_folder / "in_container_recording.json").exists():
os.remove(parent_folder / "in_container_recording.json")
if (parent_folder / "in_container_recording.pickle").exists():
os.remove(parent_folder / "in_container_recording.pickle")
os.remove(parent_folder / "in_container_params.json")
os.remove(parent_folder / "in_container_sorter_script.py")
if mode == "singularity":
Expand Down

0 comments on commit 78761bc

Please sign in to comment.