Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Chris Broz <[email protected]>
  • Loading branch information
samuelbray32 and CBroz1 authored Jan 2, 2025
1 parent 7d30fc1 commit 81f3dfa
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 19 deletions.
8 changes: 3 additions & 5 deletions src/spyglass/behavior/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def create_group(
body parts to include in the group, by default None includes all from every set
"""
group_key = {
"pose_group_name": group_name,
"pose_group_name": group_name
}
if self & group_key:
warnings.warn(
Expand All @@ -55,16 +55,14 @@ def create_group(
{
**group_key,
"bodyparts": bodyparts,
},
skip_duplicates=True,
}
)
for merge_id in merge_ids:
self.Pose.insert1(
{
**group_key,
"pose_merge_id": merge_id,
},
skip_duplicates=True,
}
)

def fetch_pose_datasets(
Expand Down
28 changes: 14 additions & 14 deletions src/spyglass/behavior/moseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


@schema
class MoseqModelParams(SpyglassMixin, dj.Manual):
class MoseqModelParams(SpyglassMixin, dj.Lookup):
definition = """
model_params_name: varchar(80)
---
Expand Down Expand Up @@ -49,8 +49,7 @@ def make_training_extension_params(
"""
model_key = (MoseqModel & model_key).fetch1("KEY")
model_params = (self & model_key).fetch1("model_params")
model_params["num_epochs"] = num_epochs
model_params["initial_model"] = model_key
model_params.update({"num_epochs":num_epochs, "initial_model":model_key})
# increment param name
if new_name is None:
# increment the extension number
Expand All @@ -76,6 +75,7 @@ def make_training_extension_params(

@schema
class MoseqModelSelection(SpyglassMixin, dj.Manual):
"""Pairing of PoseGroup and moseq model params for training"""
definition = """
-> PoseGroup
-> MoseqModelParams
Expand Down Expand Up @@ -257,8 +257,7 @@ def fetch_model(self, key: dict = None):
if key is None:
key = {}
return kpms.load_checkpoint(
(self & key).fetch1("project_dir"),
(self & key).fetch1("model_name"),
*(self & key).fetch1("project_dir", "model_name")
)[0]

def get_training_progress_path(self, key: dict = None):
Expand Down Expand Up @@ -303,11 +302,12 @@ def validate_bodyparts(self, key):
merge_key = {"merge_id": key["pose_merge_id"]}
bodyparts_df = (PositionOutput & merge_key).fetch_pose_dataframe()
data_bodyparts = bodyparts_df.keys().get_level_values(0).unique().values
for bodypart in model_bodyparts:
if bodypart not in data_bodyparts:
raise ValueError(
f"Error in row {key}: " + f"Bodypart {bodypart} not in data"
)

missing = [bp for bp in model_bodyparts if bp not in data_bodyparts]
if missing:
raise ValueError(
f"PositionOutput missing bodypart(s) for {key}: {missing}"
)


@schema
Expand All @@ -321,18 +321,18 @@ class MoseqSyllable(SpyglassMixin, dj.Computed):

def make(self, key):
model = MoseqModel().fetch_model(key)
project_dir = (MoseqModel & key).fetch1("project_dir")
project_dir, model_name = (MoseqModel & key).fetch1("project_dir", "model_name")

merge_key = {"merge_id": key["pose_merge_id"]}
bodyparts = (PoseGroup & key).fetch1("bodyparts")
config = MoseqModel()._config_func(project_dir)
model_name = (MoseqModel & key).fetch1("model_name")
num_iters = (MoseqSyllableSelection & key).fetch1("num_iters")

# load data and format for moseq
video_path = (PositionOutput & merge_key).fetch_video_path()
merge_query = PositionOutput & merge_key
video_path = merge_query.fetch_video_path()
video_name = Path(video_path).stem + ".mp4"
bodyparts_df = (PositionOutput & merge_key).fetch_pose_dataframe()
bodyparts_df = merge_query.fetch_pose_dataframe()

if bodyparts is None:
bodyparts = bodyparts_df.keys().get_level_values(0).unique().values
Expand Down

0 comments on commit 81f3dfa

Please sign in to comment.