Skip to content

Commit

Permalink
refactor: main now (re-)generates the dataset when necessary
Browse files Browse the repository at this point in the history
fixes #4

Signed-off-by: Valentin De Matos <[email protected]>
  • Loading branch information
Thytu committed Mar 27, 2024
1 parent db566ab commit 7953156
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,50 @@ def train_model(
trainer.save_model(os.path.join(cfg[step].training_args.output_dir, "final"))


def generate_dataset_if_needed(cfg: DictConfig) -> None:
    """Generate the dataset if it is missing or its generation params changed.

    A SHA-256 fingerprint of the data-related part of `cfg` is stored in
    `outputs/dataset/.hash` after each successful generation. On the next
    call the fingerprint is recomputed and compared with the stored one;
    the dataset is (re-)generated when the dataset directory or the hash
    file is missing, or when the two fingerprints differ.

    The `model.decoder` entry and the top-level `pretraining` / `training`
    sections are removed before hashing, so changing them does not trigger
    a regeneration.

    Args:
        cfg (DictConfig): hydra config
    """
    # Imported locally so this function does not depend on the module's
    # import block being updated.
    import hashlib

    PATH_TO_DATASET_DIR = "outputs/dataset/"
    PATH_TO_HASH_FILE = "outputs/dataset/.hash"

    cfg_dict = OmegaConf.to_container(cfg)

    # Drop the sections that are excluded from the fingerprint. Missing keys
    # are tolerated so a config without them does not crash here.
    model_cfg = cfg_dict.get("model")
    if isinstance(model_cfg, dict):
        model_cfg.pop("decoder", None)
    for key_to_ignore in ("pretraining", "training"):
        cfg_dict.pop(key_to_ignore, None)

    # The builtin hash() is salted per process for strings, so it cannot be
    # compared across runs; use a stable digest instead.
    serialized_cfg = json.dumps(cfg_dict, sort_keys=True)
    current_hash = hashlib.sha256(serialized_cfg.encode("utf-8")).hexdigest()

    last_hash = None
    if os.path.exists(PATH_TO_DATASET_DIR) and os.path.exists(PATH_TO_HASH_FILE):
        with open(PATH_TO_HASH_FILE, encoding="utf-8") as f:
            last_hash = f.read().strip()

    if last_hash == current_hash:
        return

    generate_dataset()

    # Only record the fingerprint once generation has succeeded, so a failed
    # run is retried next time.
    os.makedirs(PATH_TO_DATASET_DIR, exist_ok=True)
    with open(PATH_TO_HASH_FILE, "w", encoding="utf-8") as f:
        f.write(current_hash)


@hydra.main(version_base=None, config_path="../conf", config_name="default")
def main(cfg : DictConfig):

generate_dataset_if_needed(cfg)

path_to_projector = None

if cfg.get(TrainingStep.PRETRAINING) is not None:
Expand Down

0 comments on commit 7953156

Please sign in to comment.