Skip to content

Commit

Permalink
chore: add hf_model_commit_message to LocusToGeneStep (opentarget…
Browse files Browse the repository at this point in the history
  • Loading branch information
ireneisdoomed authored Nov 7, 2024
1 parent 0d3c01b commit 93de448
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/gentropy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ class LocusToGeneConfig(StepConfig):
)
wandb_run_name: str | None = None
hf_hub_repo_id: str | None = "opentargets/locus_to_gene"
hf_model_commit_message: str | None = "chore: update model"
download_from_hub: bool = True
_target_: str = "gentropy.l2g.LocusToGeneStep"

Expand Down Expand Up @@ -633,6 +634,7 @@ class LocusToGeneEvidenceStepConfig(StepConfig):
locus_to_gene_threshold: float = 0.05
_target_: str = "gentropy.l2g.LocusToGeneEvidenceStep"


@dataclass
class LocusToGeneAssociationsStepConfig(StepConfig):
"""Configuration of the locus to gene association step."""
Expand All @@ -643,6 +645,7 @@ class LocusToGeneAssociationsStepConfig(StepConfig):
indirect_associations_output_path: str = MISSING
_target_: str = "gentropy.l2g.LocusToGeneAssociationsStep"


@dataclass
class StudyLocusValidationStepConfig(StepConfig):
"""Configuration of the study index validation step.
Expand Down
7 changes: 5 additions & 2 deletions src/gentropy/l2g.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
gene_interactions_path: str | None = None,
predictions_path: str | None = None,
hf_hub_repo_id: str | None,
hf_model_commit_message: str | None = "chore: update model",
) -> None:
"""Initialise the step and run the logic based on mode.
Expand All @@ -128,6 +129,7 @@ def __init__(
gene_interactions_path (str | None): Path to the gene interactions dataset
predictions_path (str | None): Path to the L2G predictions output dataset
hf_hub_repo_id (str | None): Hugging Face Hub repository ID. If provided, the model will be uploaded to Hugging Face.
hf_model_commit_message (str | None): Commit message when we upload the model to the Hugging Face Hub
Raises:
ValueError: If run_mode is not 'train' or 'predict'
Expand All @@ -146,6 +148,7 @@ def __init__(
self.wandb_run_name = wandb_run_name
self.hf_hub_repo_id = hf_hub_repo_id
self.download_from_hub = download_from_hub
self.hf_model_commit_message = hf_model_commit_message

# Load common inputs
self.credible_set = StudyLocus.from_parquet(
Expand Down Expand Up @@ -219,7 +222,7 @@ def run_train(self) -> None:
).train(self.wandb_run_name)
if trained_model.training_data and trained_model.model and self.model_path:
trained_model.save(self.model_path)
if self.hf_hub_repo_id:
if self.hf_hub_repo_id and self.hf_model_commit_message:
hf_hub_token = access_gcp_secret(
"hfhub-key", "open-targets-genetics-dev"
)
Expand All @@ -231,7 +234,7 @@ def run_train(self) -> None:
"goldStandardSet", "geneId"
).toPandas(),
repo_id=self.hf_hub_repo_id,
commit_message="chore: update model",
commit_message=self.hf_model_commit_message,
)

def _annotate_gold_standards_w_feature_matrix(self) -> L2GFeatureMatrix:
Expand Down

0 comments on commit 93de448

Please sign in to comment.