Skip to content

Commit

Permalink
Implement suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelbray32 committed Jan 3, 2025
1 parent c7ea1f5 commit 5901edc
Showing 1 changed file with 50 additions and 13 deletions.
63 changes: 50 additions & 13 deletions src/spyglass/behavior/v1/moseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,19 +193,37 @@ def make(self, key):
}
)

def _make_model_name(self, key: dict = None):
def _make_model_name(self, key: dict):
# make a unique model name based on the key
if key is None:
key = {}
key = (MoseqModelSelection & key).fetch1("KEY")
return dj.hash.key_hash(key)

@staticmethod
def _initialize_model(
data, metadata, project_dir, model_name, config, model_params
data: dict,
metadata: tuple,
project_dir: str,
model_name: str,
config: dict,
model_params: dict,
):
"""Method to initialize a model. Creates model and runs initional ARHMM fit
Parameters
----------
data : dict
data dictionary (get from kpms.format_data)
metadata : tuple
metadata tuple (get from kpms.format_data)
project_dir : str
path to the project directory
model_name : str
name of the model
config : dict
keypoint moseq config
model_params : dict
params dictionary fetched from spyglass parameter table entry
Returns
-------
tuple
Expand All @@ -229,20 +247,20 @@ def _initialize_model(
model_name=model_name + "_ar",
)

def analyze_pca(self, key: dict = None):
def analyze_pca(self, key: dict, explained_variace: float = 0.9):
"""Method to analyze the PCA of a model
Parameters
----------
key : dict
key to a single MoseqModel table entry
explained_variace : float, optional
minimum explained variance to print, by default 0.9
"""
if key is None:
key = {}
project_dir = (self & key).fetch1("project_dir")
pca = kpms.load_pca(project_dir)
config = kpms.load_config(project_dir)
kpms.print_dims_to_explain_variance(pca, 0.9)
kpms.print_dims_to_explain_variance(pca, explained_variace)
kpms.plot_scree(pca, project_dir=project_dir)
kpms.plot_pcs(pca, project_dir=project_dir, **config)

Expand Down Expand Up @@ -280,8 +298,9 @@ def get_training_progress_path(self, key: dict = None):
"""
if key is None:
key = {}
project_dir = (self & key).fetch1("project_dir")
model_name = (self & key).fetch1("model_name")
project_dir, model_name = (self & key).fetch1(
"project_dir", "model_name"
)
return f"{project_dir}/{model_name}/fitting_progress.pdf"


Expand All @@ -306,7 +325,9 @@ def validate_bodyparts(self, key):
model_bodyparts = (PoseGroup & key).fetch1("bodyparts")
merge_key = {"merge_id": key["pose_merge_id"]}
bodyparts_df = (PositionOutput & merge_key).fetch_pose_dataframe()
data_bodyparts = bodyparts_df.keys().get_level_values(0).unique().values
data_bodyparts = MoseqSyllable.get_bodyparts_from_dataframe(
bodyparts_df
)

missing = [bp for bp in model_bodyparts if bp not in data_bodyparts]
if missing:
Expand Down Expand Up @@ -338,11 +359,11 @@ def make(self, key):
# load data and format for moseq
merge_query = PositionOutput & merge_key
video_path = merge_query.fetch_video_path()
video_name = Path(video_path).stem + ".mp4"
video_name = Path(video_path).name
bodyparts_df = merge_query.fetch_pose_dataframe()

if bodyparts is None:
bodyparts = bodyparts_df.keys().get_level_values(0).unique().values
bodyparts = self.get_bodyparts_from_dataframe(bodyparts_df)
datasets = {video_name: bodyparts_df[bodyparts]}
coordinates, confidences = format_dataset_for_moseq(datasets, bodyparts)
data, metadata = kpms.format_data(coordinates, confidences, **config)
Expand Down Expand Up @@ -379,3 +400,19 @@ def fetch1_dataframe(self):
dataframe = self.fetch_nwb()[0]["moseq"]
dataframe.set_index("time", inplace=True)
return dataframe

@staticmethod
def get_bodyparts_from_dataframe(dataframe):
"""Method to get the list of bodyparts from a dataframe
Parameters
----------
dataframe : pd.DataFrame
dataframe with bodypart data from PositionOutput
Returns
-------
List[str]
list of bodyparts
"""
return dataframe.keys().get_level_values(0).unique().values

0 comments on commit 5901edc

Please sign in to comment.