diff --git a/src/spyglass/behavior/moseq.py b/src/spyglass/behavior/moseq.py index 181ed4cac..85e0573de 100644 --- a/src/spyglass/behavior/moseq.py +++ b/src/spyglass/behavior/moseq.py @@ -6,6 +6,7 @@ from spyglass.common import AnalysisNwbfile from spyglass.position.position_merge import PositionOutput +from spyglass.settings import moseq_project_dir, moseq_video_dir from spyglass.utils import SpyglassMixin from .core import PoseGroup, format_dataset_for_moseq, results_to_df @@ -113,15 +114,8 @@ def make(self, key): model_params = (MoseqModelParams & key).fetch1("model_params") # set up the project and config - project_dir = ( - "/home/sambray/Documents/moseq_test_proj3" # TODO: make this better - ) - video_dir = ( - "/home/sambray/Documents/moseq_test_vids3" # TODO: make this better - ) + project_dir, video_dir = moseq_project_dir, moseq_video_dir # make symlinks to the videos in a single directory - os.makedirs(video_dir, exist_ok=True) - # os.makedirs(project_dir, exist_ok=True) video_paths = (PoseGroup & key).fetch_video_paths() for video in video_paths: destination = os.path.join(video_dir, os.path.basename(video)) diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index 15bcf9d9c..068607386 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -82,6 +82,10 @@ def __init__(self, base_dir: str = None, **kwargs) -> None: "video": "video", "output": "output", }, + "moseq": { + "project": "projects", + "video": "video", + }, } self.dj_defaults = { "database.host": kwargs.get("database_host", "lmf-db.cin.ucsf.edu"), @@ -139,6 +143,7 @@ def load_config( dj_spyglass = dj_custom.get("spyglass_dirs", {}) dj_kachery = dj_custom.get("kachery_dirs", {}) dj_dlc = dj_custom.get("dlc_dirs", {}) + dj_moseq = dj_custom.get("moseq_dirs", {}) self._debug_mode = dj_custom.get("debug_mode", False) self._test_mode = kwargs.get("test_mode") or dj_custom.get( @@ -174,9 +179,20 @@ def load_config( ) Path(self._dlc_base).mkdir(exist_ok=True) + self._moseq_base = ( + dj_moseq.get("base") + or os.environ.get("MOSEQ_BASE_DIR") + or str(Path(resolved_base) / "moseq") + ) + Path(self._moseq_base).mkdir(exist_ok=True) + config_dirs = {"SPYGLASS_BASE_DIR": str(resolved_base)} for prefix, dirs in self.relative_dirs.items(): - this_base = self._dlc_base if prefix == "dlc" else resolved_base + this_base = ( + self._dlc_base + if prefix == "dlc" + else (self._moseq_base if prefix == "moseq" else resolved_base) + ) for dir, dir_str in dirs.items(): dir_env_fmt = self.dir_to_var(dir=dir, dir_type=prefix) @@ -185,12 +201,14 @@ def load_config( if not self.supplied_base_dir else None ) - - source_config = ( - dj_dlc - if prefix == "dlc" - else dj_kachery if prefix == "kachery" else dj_spyglass - ) + if prefix == "dlc": + source_config = dj_dlc + elif prefix == "moseq": + source_config = dj_moseq + elif prefix == "kachery": + source_config = dj_kachery + else: + source_config = dj_spyglass dir_location = ( source_config.get(dir) or env_loc @@ -482,6 +500,11 @@ def _dj_custom(self) -> dict: "video": self.dlc_video_dir, "output": self.dlc_output_dir, }, + "moseq_dirs": { + "base": self._moseq_base, + "project": self.moseq_project_dir, + "video": self.moseq_video_dir, + }, "kachery_zone": "franklab.default", } } @@ -567,6 +590,16 @@ def dlc_output_dir(self) -> str: """DLC output directory as a string.""" return self.config.get(self.dir_to_var("output", "dlc")) + @property + def moseq_project_dir(self) -> str: + """Moseq project directory as a string.""" + return self.config.get(self.dir_to_var("project", "moseq")) + + @property + def moseq_video_dir(self) -> str: + """Moseq video directory as a string.""" + return self.config.get(self.dir_to_var("video", "moseq")) + sg_config = SpyglassConfig() sg_config.load_config(on_startup=True) @@ -588,6 +621,8 @@ def dlc_output_dir(self) -> str: dlc_project_dir = None dlc_video_dir = None dlc_output_dir = None + moseq_project_dir = None + moseq_video_dir = None else: config = sg_config.config base_dir = sg_config.base_dir @@ -605,3 +640,5 @@ def dlc_output_dir(self) -> str: dlc_project_dir = sg_config.dlc_project_dir dlc_video_dir = sg_config.dlc_video_dir dlc_output_dir = sg_config.dlc_output_dir + moseq_project_dir = sg_config.moseq_project_dir + moseq_video_dir = sg_config.moseq_video_dir