Skip to content

Commit

Permalink
cleanup config method
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelbray32 committed Jan 3, 2025
1 parent bd29c76 commit c7ea1f5
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions src/spyglass/behavior/v1/moseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,13 @@ def make(self, key):
posterior_bodyparts=model_params.get("posterior_bodyparts", None),
)

config = self._config_func(project_dir)
config = kpms.load_config(project_dir)

# fetch the data and format it for moseq
coordinates, confidences = PoseGroup().fetch_pose_datasets(
key, format_for_moseq=True
)
data, metadata = kpms.format_data(coordinates, confidences, **config())
data, metadata = kpms.format_data(coordinates, confidences, **config)

# either initialize a new model or load an existing one
model_name = self._make_model_name(key)
Expand Down Expand Up @@ -193,9 +193,6 @@ def make(self, key):
}
)

def _config_func(self, project_dir):
return lambda: kpms.load_config(project_dir)

def _make_model_name(self, key: dict = None):
# make a unique model name based on the key
if key is None:
Expand All @@ -215,11 +212,11 @@ def _initialize_model(
model, model_name
"""
# fit pca of data
pca = kpms.fit_pca(**data, **config())
pca = kpms.fit_pca(**data, **config)
kpms.save_pca(pca, project_dir)

# create the model
model = kpms.init_model(data, pca=pca, **config())
model = kpms.init_model(data, pca=pca, **config)
# run the autoregressive fit on the model
num_ar_iters = model_params["num_ar_iters"]
return kpms.fit_model(
Expand All @@ -244,10 +241,10 @@ def analyze_pca(self, key: dict = None):
key = {}
project_dir = (self & key).fetch1("project_dir")
pca = kpms.load_pca(project_dir)
config = self._config_func(project_dir)
config = kpms.load_config(project_dir)
kpms.print_dims_to_explain_variance(pca, 0.9)
kpms.plot_scree(pca, project_dir=project_dir)
kpms.plot_pcs(pca, project_dir=project_dir, **config())
kpms.plot_pcs(pca, project_dir=project_dir, **config)

def fetch_model(self, key: dict = None):
"""Method to fetch the model from the MoseqModel table
Expand Down Expand Up @@ -335,7 +332,7 @@ def make(self, key):

merge_key = {"merge_id": key["pose_merge_id"]}
bodyparts = (PoseGroup & key).fetch1("bodyparts")
config = MoseqModel()._config_func(project_dir)
config = kpms.load_config(project_dir)
num_iters = (MoseqSyllableSelection & key).fetch1("num_iters")

# load data and format for moseq
Expand All @@ -348,7 +345,7 @@ def make(self, key):
bodyparts = bodyparts_df.keys().get_level_values(0).unique().values
datasets = {video_name: bodyparts_df[bodyparts]}
coordinates, confidences = format_dataset_for_moseq(datasets, bodyparts)
data, metadata = kpms.format_data(coordinates, confidences, **config())
data, metadata = kpms.format_data(coordinates, confidences, **config)

# apply model
results = kpms.apply_model(
Expand All @@ -357,7 +354,7 @@ def make(self, key):
metadata,
project_dir,
model_name,
**config(),
**config,
num_iters=num_iters,
)

Expand Down

0 comments on commit c7ea1f5

Please sign in to comment.