diff --git a/src/main.py b/src/main.py
index 257307f..5fd9fca 100644
--- a/src/main.py
+++ b/src/main.py
@@ -4,6 +4,7 @@
 import wandb
 import hydra
 import torch
+import hashlib
 import numpy as np
 
 from SMIT import SMIT
@@ -178,30 +179,36 @@
 def generate_dataset_if_needed(cfg: DictConfig) -> None:
     PATH_TO_HASH_FILE = "outputs/dataset/.hash"
 
-    if not os.path.exists("outputs/dataset/"):
-        return
-
-    if not os.path.exists(PATH_TO_HASH_FILE):
-        return
-
-    cfg_dict = OmegaConf.to_container(cfg)
+    cfg_dict = OmegaConf.to_container(
+        cfg,
+        resolve=True,
+        throw_on_missing=True,
+    )
 
     cfg_dict["model"].pop("decoder")
     for key_to_ignore in ("pretraining", "training"):
         cfg_dict.pop(key_to_ignore, None)
 
-    with open(PATH_TO_HASH_FILE) as f:
-        last_hash = f.read()
+    if os.path.exists(PATH_TO_HASH_FILE):
+        with open(PATH_TO_HASH_FILE) as f:
+            last_hash = f.read()
+    else:
+        last_hash = ""
+
+    hasher = hashlib.new(
+        name='sha256',
+        data=json.dumps(cfg_dict, sort_keys=True).encode('utf-8'),
+    )
 
-    current_hash = hash(json.dumps(cfg_dict, sort_keys=True))
+    current_hash = hasher.hexdigest()
 
     if last_hash == current_hash:
        return
 
-    generate_dataset()
+    generate_dataset(cfg)
 
     with open(PATH_TO_HASH_FILE, "w+") as f:
-        f.write(current_hash)
+        f.write(str(current_hash))
 
 
 @hydra.main(version_base=None, config_path="../conf", config_name="default")