diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 2b3cf86f6..b7bde226a 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -9,7 +9,7 @@ import shutil import subprocess import sys -from typing import Any, Optional +from typing import Any, Dict, Optional from absl import flags from clu import metric_writers @@ -96,14 +96,14 @@ def write_hparams(hparams: spec.Hyperparameters, return hparams -def write_json(name: str, log_dict: dict, indent: int = 2) -> None: +def write_json(name: str, log_dict: Dict, indent: int = 2) -> None: if RANK == 0: with open(name, 'w') as f: f.write(json.dumps(log_dict, indent=indent)) def write_to_csv( - metrics: dict, + metrics: Dict, csv_path: str, ) -> None: try: @@ -120,7 +120,7 @@ def write_to_csv( return -def _get_utilization() -> dict: +def _get_utilization() -> Dict: util_data = {} # CPU @@ -180,7 +180,7 @@ def _get_utilization() -> dict: return util_data -def _get_system_hardware_info() -> dict: +def _get_system_hardware_info() -> Dict: system_hardware_info = {} try: system_hardware_info['cpu_model_name'] = _get_cpu_model_name() @@ -200,7 +200,7 @@ def _get_system_hardware_info() -> dict: return system_hardware_info -def _get_system_software_info() -> dict: +def _get_system_software_info() -> Dict: system_software_info = {} system_software_info['os_platform'] = \ @@ -243,7 +243,7 @@ def _is_primitive_type(item: Any) -> bool: return isinstance(item, primitive) -def _get_workload_properties(workload: spec.Workload) -> dict: +def _get_workload_properties(workload: spec.Workload) -> Dict: workload_properties = {} skip_list = ['param_shapes', 'model_params_types'] keys = [ @@ -262,7 +262,8 @@ def _get_workload_properties(workload: spec.Workload) -> dict: return workload_properties -def get_meta_data(workload: spec.Workload) -> dict: +def get_meta_data(workload: spec.Workload, + rng_seed: Optional[int] = None) -> Dict: meta_data = {} workload_properties = _get_workload_properties(workload) meta_data.update(workload_properties) @@ -272,15 +273,11 @@ def get_meta_data(workload: spec.Workload) -> dict: meta_data.update(system_software_info) system_hardware_info = _get_system_hardware_info() meta_data.update(system_hardware_info) + if rng_seed is not None: + meta_data.update({'rng_seed': rng_seed}) return meta_data -def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): - meta_data = get_meta_data(workload) - meta_data.update({'rng_seed': rng_seed}) - write_json(meta_file_name, meta_data) - - class MetricLogger(object): """Used to log all measurements during training. @@ -308,7 +305,7 @@ def __init__(self, wandb.config.update(hyperparameters._asdict()) def append_scalar_metrics(self, - metrics: dict, + metrics: Dict, global_step: int, preemption_count: Optional[int] = None, is_eval: bool = False) -> None: diff --git a/submission_runner.py b/submission_runner.py index 47730d3fc..56628d602 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -237,7 +237,6 @@ def train_once( else: logging.info('Performing `torch.compile`.') model_params = torch.compile(model_params) - logging.info('Initializing optimizer.') with profiler.profile('Initializing optimizer'): optimizer_state = init_optimizer_state(workload, @@ -284,7 +283,8 @@ def train_once( checkpoint_dir=log_dir) meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') - logger_utils.save_meta_data(workload, rng_seed, preemption_count) + meta_data = logger_utils.get_meta_data(workload, rng_seed) + logger_utils.write_json(meta_file_name, meta_data) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') logging.info(f'Saving flags to {flag_file_name}.') logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict())