diff --git a/src/spyglass/behavior/core.py b/src/spyglass/behavior/core.py index 79b804c6c..0ab674122 100644 --- a/src/spyglass/behavior/core.py +++ b/src/spyglass/behavior/core.py @@ -44,7 +44,7 @@ def create_group( body parts to include in the group, by default None includes all from every set """ group_key = { - "pose_group_name": group_name, + "pose_group_name": group_name } if self & group_key: warnings.warn( @@ -55,16 +55,14 @@ def create_group( { **group_key, "bodyparts": bodyparts, - }, - skip_duplicates=True, + } ) for merge_id in merge_ids: self.Pose.insert1( { **group_key, "pose_merge_id": merge_id, - }, - skip_duplicates=True, + } ) def fetch_pose_datasets( diff --git a/src/spyglass/behavior/moseq.py b/src/spyglass/behavior/moseq.py index b5f2132f7..9c5f3e6bb 100644 --- a/src/spyglass/behavior/moseq.py +++ b/src/spyglass/behavior/moseq.py @@ -15,7 +15,7 @@ @schema -class MoseqModelParams(SpyglassMixin, dj.Manual): +class MoseqModelParams(SpyglassMixin, dj.Lookup): definition = """ model_params_name: varchar(80) --- @@ -49,8 +49,7 @@ def make_training_extension_params( """ model_key = (MoseqModel & model_key).fetch1("KEY") model_params = (self & model_key).fetch1("model_params") - model_params["num_epochs"] = num_epochs - model_params["initial_model"] = model_key + model_params.update({"num_epochs":num_epochs, "initial_model":model_key}) # increment param name if new_name is None: # increment the extension number @@ -76,6 +75,7 @@ def make_training_extension_params( @schema class MoseqModelSelection(SpyglassMixin, dj.Manual): + """Pairing of PoseGroup and moseq model params for training""" definition = """ -> PoseGroup -> MoseqModelParams @@ -257,8 +257,7 @@ def fetch_model(self, key: dict = None): if key is None: key = {} return kpms.load_checkpoint( - (self & key).fetch1("project_dir"), - (self & key).fetch1("model_name"), + *(self & key).fetch1("project_dir", "model_name") )[0] def get_training_progress_path(self, key: dict = None): @@ -303,11 +302,12 @@ def validate_bodyparts(self, key): merge_key = {"merge_id": key["pose_merge_id"]} bodyparts_df = (PositionOutput & merge_key).fetch_pose_dataframe() data_bodyparts = bodyparts_df.keys().get_level_values(0).unique().values - for bodypart in model_bodyparts: - if bodypart not in data_bodyparts: - raise ValueError( - f"Error in row {key}: " + f"Bodypart {bodypart} not in data" - ) + + missing = [bp for bp in model_bodyparts if bp not in data_bodyparts] + if missing: + raise ValueError( + f"PositionOutput missing bodypart(s) for {key}: {missing}" + ) @schema @@ -321,18 +321,18 @@ class MoseqSyllable(SpyglassMixin, dj.Computed): def make(self, key): model = MoseqModel().fetch_model(key) - project_dir = (MoseqModel & key).fetch1("project_dir") + project_dir, model_name = (MoseqModel & key).fetch1("project_dir", "model_name") merge_key = {"merge_id": key["pose_merge_id"]} bodyparts = (PoseGroup & key).fetch1("bodyparts") config = MoseqModel()._config_func(project_dir) - model_name = (MoseqModel & key).fetch1("model_name") num_iters = (MoseqSyllableSelection & key).fetch1("num_iters") # load data and format for moseq - video_path = (PositionOutput & merge_key).fetch_video_path() + merge_query = PositionOutput & merge_key + video_path = merge_query.fetch_video_path() video_name = Path(video_path).stem + ".mp4" - bodyparts_df = (PositionOutput & merge_key).fetch_pose_dataframe() + bodyparts_df = merge_query.fetch_pose_dataframe() if bodyparts is None: bodyparts = bodyparts_df.keys().get_level_values(0).unique().values