diff --git a/scoring/__init__.py b/scoring/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/scoring/scoring.py b/scoring/scoring.py
new file mode 100644
index 000000000..ba19ec026
--- /dev/null
+++ b/scoring/scoring.py
@@ -0,0 +1,349 @@
+"""Performance and scoring code.
+
+The three primary methods exposed by the `scoring` module are:
+- `compute_performance_profiles`: generates performance profiles for a set of
+  submissions over all workloads as defined in the scoring rules:
+  https://github.com/mlcommons/algorithmic-efficiency/blob/main/RULES.md
+- `compute_leaderboard_score`: computes final scores from performance profiles.
+- `plot_performance_profiles`: plots performance profiles for a set of
+  submissions.
+
+The two primary inputs to `compute_performance_profiles` are
+1. A dictionary of pandas DataFrames, where each key is a globally unique
+   identifier for a submission and each value is a DataFrame containing one row
+   per trial per workload in that submission. At minimum, this DataFrame should
+   include a column of np.arrays indicating time (e.g., 'global_step'), a
+   column of np.arrays indicating performance (e.g., 'valid/accuracy') for
+   each workload, and a column 'workload' that indicates the workload
+   identifier.
+2. A dictionary of workload metadata describing each workload in the form:
+   {
+     'workload_identifier': {
+       'target': VALUE,
+       'metric': 'valid/error_rate',
+     }
+   }
+   The keys in this dictionary should match the workload identifiers used in
+   the dictionary of submissions.
+"""
+
+import itertools
+import operator
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+
+MIN_EVAL_METRICS = [
+    'ce_loss',
+    'error_rate',
+    'ctc_loss',
+    'wer',
+    'l1_loss',
+]
+MAX_EVAL_METRICS = ['average_precision', 'ssim', 'bleu_score']
+
+
+def generate_eval_cols(metrics):
+  splits = ['train', 'valid', 'test']
+  return [f'{split}/{col}' for split, col in itertools.product(splits, metrics)]
+
+
+MINIMIZE_REGISTRY = {k: True for k in generate_eval_cols(MIN_EVAL_METRICS)}
+MINIMIZE_REGISTRY.update(
+    {k: False for k in generate_eval_cols(MAX_EVAL_METRICS)})
+MINIMIZE_REGISTRY['train_cost'] = True
+
+
+def check_if_minimized(col_name):
+  """Guess if the eval metric column name should be minimized or not."""
+  for prefix in ['best_', 'final_']:
+    col_name = col_name.replace(prefix, '')
+
+  for col in MINIMIZE_REGISTRY:
+    if col in col_name:
+      return MINIMIZE_REGISTRY[col]
+
+  raise ValueError(f'Column {col_name} not found in `MINIMIZE_REGISTRY` as '
                   'either a column name or a substring of a column name.')
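+
+
+# Illustrative usage sketch (comments only; these example calls assume the
+# registry entries defined above):
+#   check_if_minimized('valid/error_rate')  # -> True: error rates are
+#                                           #    minimized.
+#   check_if_minimized('best_test/ssim')    # -> False: the 'best_' prefix is
+#                                           #    stripped and ssim is maximized.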
+
+
+def get_index_that_reaches_best(workload_df, metric_col):
+  """Get the eval index at which a workload reaches its best on metric_col.
+
+  Args:
+    workload_df: A subset of a submission's trials DataFrame that
+      includes only the trials in a single workload.
+    metric_col: Name of array column in workload_df (e.g., `valid/l1_loss`).
+
+  Returns:
+    Tuple of trial index, time index, and best value where the workload
+    reached the best metric_col. Returns (-1, -1, -1) if there are no
+    undiverged trials.
+  """
+  is_minimized = check_if_minimized(metric_col)
+  series = workload_df[metric_col]
+
+  # Drop trials whose metrics are NaN (diverged trials). Note that an
+  # elementwise `series != np.nan` comparison is always True, so an explicit
+  # dropna() is required here.
+  series = series.dropna()
+
+  op = np.min if is_minimized else np.max
+  best = series.apply(op)
+
+  op_idx = np.argmin if is_minimized else np.argmax
+  best_idx = series.apply(op_idx)
+
+  if best.empty:
+    return -1, -1, -1
+  else:
+    trial = best.idxmin() if is_minimized else best.idxmax()
+    return trial, best_idx[trial], best[trial]
+
+
+def get_index_that_reaches_target(workload_df, metric_col, target):
+  """Get the eval index at which a workload reaches the target on metric_col.
+
+  Args:
+    workload_df: A subset of a submission's trials DataFrame that
+      includes only the trials in a single workload.
+    metric_col: Name of array column in workload_df (e.g., `valid/l1_loss`).
+    target: Target value for metric_col.
+
+  Returns:
+    Tuple of trial index and time index where the workload reached the target
+    metric_col. Returns (-1, -1) if the target is never reached.
+  """
+  is_minimized = check_if_minimized(metric_col)
+  series = workload_df[metric_col]
+
+  # Drop diverged trials (see note in `get_index_that_reaches_best`).
+  series = series.dropna()
+
+  op = operator.le if is_minimized else operator.ge
+  target_reached = series.apply(lambda x: op(x, target))
+
+  # Remove trials that never reach the target.
+  target_reached = target_reached[target_reached.apply(np.any)]
+
+  # If no trial reaches the target, return -1. Otherwise, return the eval
+  # index of the earliest point at which the target is reached.
+  if target_reached.empty:
+    return -1, -1
+  else:
+    index_reached = target_reached.apply(np.argmax)
+    trial = index_reached.idxmin()
+    return trial, index_reached[trial]
+
+
+def get_times_for_submission(submission,
+                             submission_tag,
+                             workload_metadata,
+                             time_col='global_step',
+                             verbosity=1):
+  """Get times to target for each workload in a submission.
+
+  Args:
+    submission: A DataFrame containing one row for each trial in each workload
+      for a given submission.
+    submission_tag: Globally unique identifier for a submission.
+    workload_metadata: Dictionary keyed by workload names with value of
+      dictionary with `target` and `metric` as keys.
+    time_col: A string indicating which column to use for time.
+    verbosity: Debug level of information; choice of (1, 2, 3).
+
+  Returns:
+    DataFrame with columns `submission`, `workload`, and time_col.
+  """
+  workloads = []
+  submission_name = submission_tag.split('.')[1]
+
+  for workload, group in submission.groupby('workload'):
+    metric = workload_metadata[workload]['metric']
+    target = workload_metadata[workload]['target']
+    trial_idx, time_idx = get_index_that_reaches_target(group, metric, target)
+    if time_idx > -1:
+      time_val = group[time_col].loc[trial_idx][time_idx]
+    else:
+      time_val = float('inf')
+
+    workloads.append({
+        'submission': submission_name,
+        'workload': workload,
+        time_col: time_val,
+    })
+
+    if verbosity > 0:
+      print('  hparams:')
+      if time_idx > -1:
+        hparams = group.loc[trial_idx, 'hparams']
+        for key, val in hparams.items():
+          print(f'  - {key}: {val}')
+      else:
+        print('Submission did not reach target')
+  df = pd.DataFrame.from_records(workloads)
+  df = df.pivot(index='submission', columns='workload', values=time_col)
+
+  return df
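+
+
+# Illustrative sketch (comments only; the workload name, columns, and tag
+# below are made up for the example, not a fixed schema):
+#   submission = pd.DataFrame({
+#       'workload': ['fastmri', 'fastmri'],
+#       'global_step': [np.array([100, 200]), np.array([100, 200])],
+#       'valid/ssim': [np.array([0.60, 0.72]), np.array([0.58, 0.69])],
+#       'hparams': [{'learning_rate': 1e-3}, {'learning_rate': 1e-2}],
+#   }, index=['trial_0', 'trial_1'])
+#   metadata = {'fastmri': {'metric': 'valid/ssim', 'target': 0.7}}
+#   get_times_for_submission(submission, 'study_0.adamw', metadata)
+# Only trial_0 reaches the 0.7 SSIM target (at eval index 1), so the returned
+# DataFrame maps submission 'adamw' and workload 'fastmri' to global_step 200.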
+
+
+def compute_performance_profiles(results,
+                                 workload_metadata,
+                                 time_col='global_step',
+                                 min_tau=1.0,
+                                 max_tau=None,
+                                 reference_submission_tag=None,
+                                 num_points=100,
+                                 scale='linear',
+                                 verbosity=0):
+  """Compute performance profiles for a set of submissions by some time column.
+
+  Args:
+    results: Dict where keys are submission names and values are a DataFrame of
+      trials where each row is a trial and each column is a field for a given
+      trial. Results should contain keys for each workload's metric, time_col,
+      'workload'. See file header comment for more details.
+    workload_metadata: Dictionary keyed by workload names with value of
+      dictionary with `target` and `metric` as keys.
+    time_col: A string indicating which column to use for time.
+    min_tau: Minimum tau to use for plotting.
+    max_tau: Maximum tau to use for plotting.
+    reference_submission_tag: If specified, must be one of the submission names
+      in `results`. Used as the denominator for computing tau. Otherwise,
+      the minimum time to target is computed per-workload and used as the
+      denominator for tau.
+    num_points: Number of points to use for plotting.
+    scale: Linear or log scale for the x-axis.
+    verbosity: Debug level of information; choice of (1, 2, 3).
+
+  Returns:
+    A DataFrame of performance profiles for the set of submissions given in
+    `results` based on `time_col`. Each row represents a submission and each
+    column represents rho(tau) for some value of tau (df.columns are the
+    different values of tau).
+  """
+  dfs = []
+
+  for submission_tag, result in results.items():
+    print(f'\nComputing performance profile with respect to `{time_col}` for '
+          f'{submission_tag}')
+    dfs.append(
+        get_times_for_submission(result,
+                                 submission_tag,
+                                 workload_metadata,
+                                 time_col,
+                                 verbosity))
+  df = pd.concat(dfs)
+
+  if verbosity > 0:
+    print(f'\n`{time_col}` to reach target:')
+    with pd.option_context('display.max_rows',
+                           None,
+                           'display.max_columns',
+                           None,
+                           'display.width',
+                           1000):
+      print(df)
+
+  # Divide by the fastest.
+  if reference_submission_tag is None:
+    df.update(df.div(df.min(axis=0), axis=1))
+  else:
+    df.update(df.div(df.loc[reference_submission_tag, :], axis=1))
+
+  if verbosity > 0:
+    print(f'\n`{time_col}` to reach target normalized to best:')
+    with pd.option_context('display.max_rows',
+                           None,
+                           'display.max_columns',
+                           None,
+                           'display.width',
+                           1000):
+      print(df)
+
+  # If no max_tau is supplied, choose the value of tau that would plot all
+  # non-inf, non-NaN data.
+  if max_tau is None:
+    max_tau = df.replace(float('inf'), -1).replace(np.nan, -1).values.max()
+
+  if scale == 'linear':
+    points = np.linspace(min_tau, max_tau, num=num_points)
+  elif scale == 'log':
+    points = np.logspace(
+        np.log10(min_tau), np.log10(max_tau), num=num_points, base=10.0)
+
+  def rho(r, tau):
+    return (r <= tau).sum(axis=1) / len(r.columns)
+
+  perf_df = pd.concat([rho(df, tau) for tau in points], axis=1)
+
+  cols = points
+  if scale == 'log':
+    cols = np.log10(points)
+  perf_df.columns = cols
+
+  return perf_df
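+
+
+# Illustrative sketch (comments only; the submission tags are made up): with
+# two submissions scored on the same workloads,
+#   results = {'study_0.adamw': adamw_df, 'study_0.nadamw': nadamw_df}
+#   perf_df = compute_performance_profiles(results, metadata,
+#                                          time_col='global_step')
+# each row of perf_df is a submission and each entry is rho(tau): the fraction
+# of workloads that submission solves within tau times the fastest time.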
+
+
+def compute_leaderboard_score(df, normalize=False):
+  """Compute leaderboard score by integrating the performance profile.
+
+  Args:
+    df: pd.DataFrame returned from `compute_performance_profiles`.
+    normalize: Divide by the range of the performance profile's tau.
+
+  Returns:
+    pd.DataFrame with one column of scores indexed by submission.
+  """
+  scores = np.trapz(df, x=df.columns)
+  if normalize:
+    scores /= df.columns.max() - df.columns.min()
+  return pd.DataFrame(scores, columns=['score'], index=df.index)
+
+
+def maybe_save_figure(save_dir, name, ext='pdf'):
+  """Maybe save the current matplotlib.pyplot figure."""
+  if save_dir:
+    path = os.path.join(save_dir, f'{name}.{ext}')
+    with open(path, 'wb') as fout:
+      plt.savefig(fout, format=ext)
+
+
+def maybe_save_df_to_csv(save_dir, df, path, **to_csv_kwargs):
+  """Maybe save df as a CSV file under save_dir."""
+  if save_dir:
+    path = os.path.join(save_dir, path)
+    with open(path, 'w') as fout:
+      df.to_csv(fout, **to_csv_kwargs)
+
+
+def plot_performance_profiles(perf_df,
+                              df_col,
+                              scale='linear',
+                              save_dir=None,
+                              figsize=(30, 10),
+                              font_size=18):
+  """Plot performance profiles.
+
+  Args:
+    perf_df: A DataFrame of performance profiles where each row represents a
+      submission and each column represents rho(tau) for some value of tau
+      (df.columns are the different values of tau).
+    df_col: The column in the original submission results DataFrame used to
+      compute the performance profile. This argument is only used for axis
+      and file naming.
+    scale: Whether the data in perf_df is on a linear or log scale. This
+      argument is only used for axis and file naming.
+    save_dir: If a valid directory is provided, save both the plot and perf_df
+      to the provided directory.
+    figsize: The size of the plot.
+    font_size: The font size to use for the legend.
+
+  Returns:
+    None. If a valid save_dir is provided, save both the plot and perf_df.
+  """
+  fig = perf_df.T.plot(figsize=figsize)
+  df_col_display = f'log10({df_col})' if scale == 'log' else df_col
+  fig.set_xlabel(
+      f'Ratio of `{df_col_display}` to best submission', size=font_size)
+  fig.set_ylabel('Proportion of workloads', size=font_size)
+  fig.legend(prop={'size': font_size}, bbox_to_anchor=(1.0, 1.0))
+  maybe_save_figure(save_dir, f'performance_profile_by_{df_col_display}')
+  maybe_save_df_to_csv(save_dir,
+                       perf_df,
+                       f'performance_profile_{df_col_display}.csv')
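+
+
+# End-to-end sketch (comments only; assumes `results` and `workload_metadata`
+# are built as described in the module docstring, and the save path is made
+# up):
+#   perf_df = compute_performance_profiles(results, workload_metadata)
+#   scores = compute_leaderboard_score(perf_df, normalize=True)
+#   plot_performance_profiles(perf_df, 'global_step', save_dir='/tmp/scoring')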
diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py
new file mode 100644
index 000000000..bb346acdc
--- /dev/null
+++ b/scoring/scoring_utils.py
@@ -0,0 +1,129 @@
+import json
+import os
+import re
+
+import pandas as pd
+
+trial_line_regex = r'(.*) --- Tuning run (\d+)/(\d+) ---'
+metrics_line_regex = r'(.*) Metrics: ({.*})'
+
+
+#### File IO helper functions ###
+def get_logfile_paths(logdir):
+  """Gets all paths to files ending in .log in logdir."""
+  filenames = os.listdir(logdir)
+  logfile_paths = []
+  for f in filenames:
+    if f.endswith(".log"):
+      f = os.path.join(logdir, f)
+      logfile_paths.append(f)
+  return logfile_paths
+
+
+### Logfile reading helper functions ###
+def decode_metrics_line(line):
+  """Convert a metrics line to a dict.
+
+  Args:
+    line: str
+
+  Returns:
+    dict_of_lists: dict where keys are metric names and vals
+      are lists of values,
+      e.g. {'loss': [5.1, 3.2, 1.0],
+            'step': [100, 200, 300]}
+  """
+  eval_results = []
+  dict_str = re.match(metrics_line_regex, line).group(2)
+  dict_str = dict_str.replace("'", "\"")
+  dict_str = dict_str.replace("(", "")
+  dict_str = dict_str.replace(")", "")
+  dict_str = dict_str.replace("DeviceArray", "")
+  dict_str = dict_str.replace(", dtype=float32", "")
+  dict_str = dict_str.replace("nan", "0")
+  metrics_dict = json.loads(dict_str)
+  for item in metrics_dict['eval_results']:
+    if isinstance(item, dict):
+      eval_results.append(item)
+
+  keys = eval_results[0].keys()
+
+  dict_of_lists = {}
+  for key in keys:
+    dict_of_lists[key] = []
+
+  for eval_results_dict in eval_results:
+    for key in eval_results_dict.keys():
+      val = eval_results_dict[key]
+      dict_of_lists[key].append(val)
+
+  return dict_of_lists
+
+
+def get_trials_dict(logfile):
+  """Get a dict of dicts with metrics for each tuning run.
+
+  Args:
+    logfile: str path to logfile.
+
+  Returns:
+    trials_dict: Dict of dicts where outer dict keys
+      are trial indices and inner dict key-value pairs
+      are metrics and list of values,
+      e.g. {'trial_0': {'loss': [5.1, 3.2, 1.0],
+                        'step': [100, 200, 300]},
+            'trial_1': {'loss': [5.1, 3.2, 1.0],
+                        'step': [100, 200, 300]}}
+  """
+  trial = 0
+  metrics_lines = {}
+  with open(logfile, 'r') as f:
+    for line in f:
+      if re.match(trial_line_regex, line):
+        trial = re.match(trial_line_regex, line).group(2)
+      if re.match(metrics_line_regex, line):
+        metrics_lines[trial] = decode_metrics_line(line)
+  if len(metrics_lines) == 0:
+    raise ValueError(f"Log file does not have a metrics line {logfile}")
+  return metrics_lines
+
+
+### Results formatting helper functions ###
+def get_trials_df_dict(logfile):
+  """Get a dict of DataFrames with metrics for each tuning run.
+
+  This is the preferable format for saving DataFrames for tables.
+
+  Args:
+    logfile: str path to logfile.
+
+  Returns:
+    Dict mapping trial index to a DataFrame whose indices are eval indices
+    and whose columns are metric names.
+  """
+  trials_dict = get_trials_dict(logfile)
+  trials_df_dict = {}
+  for trial in trials_dict:
+    metrics = trials_dict[trial]
+    trials_df_dict[trial] = pd.DataFrame(metrics)
+  return trials_df_dict
+
+
+def get_trials_df(logfile):
+  """Gets a df of per-trial results from a logfile.
+
+  The output df can be provided as input to
+  scoring.compute_performance_profiles.
+
+  Args:
+    logfile: str path to logfile.
+
+  Returns:
+    df: DataFrame where indices are trials, columns are
+      metric names and values are lists,
+      e.g.
+      +---------+-----------------+-----------------+
+      |         | loss            | step            |
+      |---------+-----------------+-----------------|
+      | trial_0 | [5.1, 3.2, 1.0] | [100, 200, 300] |
+      | trial_1 | [5.1, 3.2, 1.0] | [100, 200, 300] |
+      +---------+-----------------+-----------------+
+  """
+  trials_dict = get_trials_dict(logfile)
+  df = pd.DataFrame(trials_dict).transpose()
+  return df
diff --git a/scoring/test_data/adamw_fastmri_jax_04-18-2023-13-10-58.log b/scoring/test_data/adamw_fastmri_jax_04-18-2023-13-10-58.log
new file mode 100644
index 000000000..5038fdae0
--- /dev/null
+++ b/scoring/test_data/adamw_fastmri_jax_04-18-2023-13-10-58.log
@@ -0,0 +1,228 @@
+I0418 13:11:18.847079 139811861223232 logger_utils.py:67] Creating experiment directory at /experiment_runs/timing_v3/timing_adamw/fastmri_jax.
+I0418 13:11:19.000787 139811861223232 xla_bridge.py:345] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
+I0418 13:11:19.900017 139811861223232 xla_bridge.py:345] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm".
Available platform names are: CUDA Host Interpreter +I0418 13:11:19.900953 139811861223232 xla_bridge.py:345] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client' +I0418 13:11:19.904862 139811861223232 submission_runner.py:528] Using RNG seed 592249298 +I0418 13:11:22.495646 139811861223232 submission_runner.py:537] --- Tuning run 1/1 --- +I0418 13:11:22.495854 139811861223232 submission_runner.py:542] Creating tuning directory at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1. +I0418 13:11:22.496118 139811861223232 logger_utils.py:83] Saving hparams to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/hparams.json. +I0418 13:11:22.619274 139811861223232 submission_runner.py:232] Initializing dataset. +I0418 13:11:26.460471 139811861223232 submission_runner.py:239] Initializing model. +I0418 13:11:33.442139 139811861223232 submission_runner.py:249] Initializing optimizer. +I0418 13:11:33.885006 139811861223232 submission_runner.py:256] Initializing metrics bundle. +I0418 13:11:33.885203 139811861223232 submission_runner.py:273] Initializing checkpoint and logger. +I0418 13:11:33.887343 139811861223232 checkpoints.py:466] Found no checkpoint files in /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1 with prefix checkpoint_ +I0418 13:11:33.887573 139811861223232 logger_utils.py:230] Unable to record workload.train_mean information. Continuing without it. +I0418 13:11:33.887642 139811861223232 logger_utils.py:230] Unable to record workload.train_stddev information. Continuing without it. +I0418 13:11:34.773985 139811861223232 submission_runner.py:294] Saving meta data to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/meta_data_0.json. +I0418 13:11:34.774941 139811861223232 submission_runner.py:297] Saving flags to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/flags_0.json. +I0418 13:11:34.781248 139811861223232 submission_runner.py:309] Starting training loop. +I0418 13:12:39.695030 139635653076736 logging_writer.py:48] [0] global_step=0, grad_norm=4.749215602874756, loss=0.7997596263885498 +I0418 13:12:39.706240 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:14:08.760860 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:15:11.717901 139811861223232 spec.py:326] Evaluating on the test split. 
+I0418 13:16:12.381375 139811861223232 submission_runner.py:406] Time since start: 277.60s, Step: 1, {'train/ssim': 0.21045025757380895, 'train/loss': 0.8323062488010952, 'validation/ssim': 0.20073387820061903, 'validation/loss': 0.8397238092202448, 'validation/num_examples': 3554, 'test/ssim': 0.22385435256169714, 'test/loss': 0.8394407577535954, 'test/num_examples': 3581, 'score': 64.92480111122131, 'total_duration': 277.60004019737244, 'accumulated_submission_time': 64.92480111122131, 'accumulated_eval_time': 212.67506647109985, 'accumulated_logging_time': 0} +I0418 13:16:12.398734 139607727380224 logging_writer.py:48] [1] accumulated_eval_time=212.675066, accumulated_logging_time=0, accumulated_submission_time=64.924801, global_step=1, preemption_count=0, score=64.924801, test/loss=0.839441, test/num_examples=3581, test/ssim=0.223854, total_duration=277.600040, train/loss=0.832306, train/ssim=0.210450, validation/loss=0.839724, validation/num_examples=3554, validation/ssim=0.200734 +I0418 13:16:12.435780 139811861223232 checkpoints.py:356] Saving checkpoint at step: 1 +I0418 13:16:12.667325 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_1 +I0418 13:16:12.668052 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_1. +I0418 13:16:34.214733 139607643518720 logging_writer.py:48] [100] global_step=100, grad_norm=0.2432398796081543, loss=0.33436650037765503 +I0418 13:16:57.787417 139607609947904 logging_writer.py:48] [200] global_step=200, grad_norm=0.1648881435394287, loss=0.36671581864356995 +I0418 13:17:21.661577 139607643518720 logging_writer.py:48] [300] global_step=300, grad_norm=0.28795504570007324, loss=0.2782801687717438 +I0418 13:17:32.762570 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:17:34.622865 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:17:35.969232 139811861223232 spec.py:326] Evaluating on the test split. +I0418 13:17:37.315605 139811861223232 submission_runner.py:406] Time since start: 362.53s, Step: 331, {'train/ssim': 0.7070232800074986, 'train/loss': 0.30231707436697824, 'validation/ssim': 0.6850338636923186, 'validation/loss': 0.32094466209948297, 'validation/num_examples': 3554, 'test/ssim': 0.7031452484684795, 'test/loss': 0.32283779763203363, 'test/num_examples': 3581, 'score': 145.01519441604614, 'total_duration': 362.53429913520813, 'accumulated_submission_time': 145.01519441604614, 'accumulated_eval_time': 217.22807550430298, 'accumulated_logging_time': 0.28686976432800293} +I0418 13:17:37.324599 139607609947904 logging_writer.py:48] [331] accumulated_eval_time=217.228076, accumulated_logging_time=0.286870, accumulated_submission_time=145.015194, global_step=331, preemption_count=0, score=145.015194, test/loss=0.322838, test/num_examples=3581, test/ssim=0.703145, total_duration=362.534299, train/loss=0.302317, train/ssim=0.707023, validation/loss=0.320945, validation/num_examples=3554, validation/ssim=0.685034 +I0418 13:17:37.406208 139811861223232 checkpoints.py:356] Saving checkpoint at step: 331 +I0418 13:17:37.661238 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_331 +I0418 13:17:37.661890 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_331. 
+I0418 13:17:58.928837 139607643518720 logging_writer.py:48] [400] global_step=400, grad_norm=0.14074426889419556, loss=0.2704832851886749 +I0418 13:18:36.867310 139607601555200 logging_writer.py:48] [500] global_step=500, grad_norm=0.4256458878517151, loss=0.20842526853084564 +I0418 13:18:57.943878 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:18:59.347492 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:19:00.696888 139811861223232 spec.py:326] Evaluating on the test split. +I0418 13:19:02.044154 139811861223232 submission_runner.py:406] Time since start: 447.26s, Step: 560, {'train/ssim': 0.7200777871268136, 'train/loss': 0.290642329624721, 'validation/ssim': 0.6985314604802687, 'validation/loss': 0.3087669977929797, 'validation/num_examples': 3554, 'test/ssim': 0.7159312359283371, 'test/loss': 0.31089522365217465, 'test/num_examples': 3581, 'score': 225.29421615600586, 'total_duration': 447.26284670829773, 'accumulated_submission_time': 225.29421615600586, 'accumulated_eval_time': 221.3283338546753, 'accumulated_logging_time': 0.633378267288208} +I0418 13:19:02.054116 139607643518720 logging_writer.py:48] [560] accumulated_eval_time=221.328334, accumulated_logging_time=0.633378, accumulated_submission_time=225.294216, global_step=560, preemption_count=0, score=225.294216, test/loss=0.310895, test/num_examples=3581, test/ssim=0.715931, total_duration=447.262847, train/loss=0.290642, train/ssim=0.720078, validation/loss=0.308767, validation/num_examples=3554, validation/ssim=0.698531 +I0418 13:19:02.141814 139811861223232 checkpoints.py:356] Saving checkpoint at step: 560 +I0418 13:19:02.376993 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_560 +I0418 13:19:02.377660 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_560. +I0418 13:19:14.267865 139607601555200 logging_writer.py:48] [600] global_step=600, grad_norm=0.44230204820632935, loss=0.26166975498199463 +I0418 13:19:51.037283 139601586935552 logging_writer.py:48] [700] global_step=700, grad_norm=0.5681076645851135, loss=0.2686598598957062 +I0418 13:20:22.674819 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:20:24.081235 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:20:25.428304 139811861223232 spec.py:326] Evaluating on the test split. 
+I0418 13:20:26.775214 139811861223232 submission_runner.py:406] Time since start: 531.99s, Step: 787, {'train/ssim': 0.7264369555882045, 'train/loss': 0.2863526003701346, 'validation/ssim': 0.7047054567696609, 'validation/loss': 0.30473246061875703, 'validation/num_examples': 3554, 'test/ssim': 0.7218843538379642, 'test/loss': 0.30675468455476823, 'test/num_examples': 3581, 'score': 305.5884253978729, 'total_duration': 531.9939126968384, 'accumulated_submission_time': 305.5884253978729, 'accumulated_eval_time': 225.4286892414093, 'accumulated_logging_time': 0.9671237468719482} +I0418 13:20:26.786030 139607601555200 logging_writer.py:48] [787] accumulated_eval_time=225.428689, accumulated_logging_time=0.967124, accumulated_submission_time=305.588425, global_step=787, preemption_count=0, score=305.588425, test/loss=0.306755, test/num_examples=3581, test/ssim=0.721884, total_duration=531.993913, train/loss=0.286353, train/ssim=0.726437, validation/loss=0.304732, validation/num_examples=3554, validation/ssim=0.704705 +I0418 13:20:26.867157 139811861223232 checkpoints.py:356] Saving checkpoint at step: 787 +I0418 13:20:27.119520 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_787 +I0418 13:20:27.120136 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_787. +I0418 13:20:28.694146 139601586935552 logging_writer.py:48] [800] global_step=800, grad_norm=0.24688591063022614, loss=0.33306989073753357 +I0418 13:21:06.572049 139580120491776 logging_writer.py:48] [900] global_step=900, grad_norm=0.27356621623039246, loss=0.29250115156173706 +I0418 13:21:39.481042 139601586935552 logging_writer.py:48] [1000] global_step=1000, grad_norm=0.1510585993528366, loss=0.29647883772850037 +I0418 13:21:47.153610 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:21:48.558691 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:21:49.905932 139811861223232 spec.py:326] Evaluating on the test split. +I0418 13:21:51.251713 139811861223232 submission_runner.py:406] Time since start: 616.47s, Step: 1034, {'train/ssim': 0.7311224256243024, 'train/loss': 0.281079752104623, 'validation/ssim': 0.7095664243853756, 'validation/loss': 0.298722438350714, 'validation/num_examples': 3554, 'test/ssim': 0.7266145187927604, 'test/loss': 0.30085532394277087, 'test/num_examples': 3581, 'score': 385.6187229156494, 'total_duration': 616.4704086780548, 'accumulated_submission_time': 385.6187229156494, 'accumulated_eval_time': 229.52676105499268, 'accumulated_logging_time': 1.3123102188110352} +I0418 13:21:51.260128 139580120491776 logging_writer.py:48] [1034] accumulated_eval_time=229.526761, accumulated_logging_time=1.312310, accumulated_submission_time=385.618723, global_step=1034, preemption_count=0, score=385.618723, test/loss=0.300855, test/num_examples=3581, test/ssim=0.726615, total_duration=616.470409, train/loss=0.281080, train/ssim=0.731122, validation/loss=0.298722, validation/num_examples=3554, validation/ssim=0.709566 +I0418 13:21:51.325284 139811861223232 checkpoints.py:356] Saving checkpoint at step: 1034 +I0418 13:21:51.540192 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_1034 +I0418 13:21:51.540703 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_1034. 
+I0418 13:22:05.197393 139601586935552 logging_writer.py:48] [1100] global_step=1100, grad_norm=0.22205762565135956, loss=0.3316890001296997 +I0418 13:22:28.457181 139562427709184 logging_writer.py:48] [1200] global_step=1200, grad_norm=0.16201959550380707, loss=0.2782382071018219 +I0418 13:22:52.147155 139601586935552 logging_writer.py:48] [1300] global_step=1300, grad_norm=0.4414421617984772, loss=0.363120973110199 +I0418 13:23:11.787177 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:23:13.190361 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:23:14.542059 139811861223232 spec.py:326] Evaluating on the test split. +I0418 13:23:15.888787 139811861223232 submission_runner.py:406] Time since start: 701.11s, Step: 1384, {'train/ssim': 0.7352303096226284, 'train/loss': 0.278595038822719, 'validation/ssim': 0.7136715452658976, 'validation/loss': 0.29653461823385624, 'validation/num_examples': 3554, 'test/ssim': 0.7307323891502024, 'test/loss': 0.2984414610640359, 'test/num_examples': 3581, 'score': 465.86059069633484, 'total_duration': 701.1074573993683, 'accumulated_submission_time': 465.86059069633484, 'accumulated_eval_time': 233.62830877304077, 'accumulated_logging_time': 1.601564884185791} +I0418 13:23:15.896825 139562427709184 logging_writer.py:48] [1384] accumulated_eval_time=233.628309, accumulated_logging_time=1.601565, accumulated_submission_time=465.860591, global_step=1384, preemption_count=0, score=465.860591, test/loss=0.298441, test/num_examples=3581, test/ssim=0.730732, total_duration=701.107457, train/loss=0.278595, train/ssim=0.735230, validation/loss=0.296535, validation/num_examples=3554, validation/ssim=0.713672 +I0418 13:23:15.934105 139811861223232 checkpoints.py:356] Saving checkpoint at step: 1384 +I0418 13:23:16.147197 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_1384 +I0418 13:23:16.147722 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_1384. +I0418 13:23:17.921429 139601586935552 logging_writer.py:48] [1400] global_step=1400, grad_norm=0.48134860396385193, loss=0.3586328625679016 +I0418 13:23:41.050539 139561673946880 logging_writer.py:48] [1500] global_step=1500, grad_norm=0.09520641714334488, loss=0.3314594626426697 +I0418 13:24:05.158188 139601586935552 logging_writer.py:48] [1600] global_step=1600, grad_norm=0.28729698061943054, loss=0.3001077473163605 +I0418 13:24:28.470937 139561673946880 logging_writer.py:48] [1700] global_step=1700, grad_norm=0.09772060811519623, loss=0.34218305349349976 +I0418 13:24:36.272807 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:24:37.679940 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:24:39.031306 139811861223232 spec.py:326] Evaluating on the test split. 
+I0418 13:24:40.381394 139811861223232 submission_runner.py:406] Time since start: 785.60s, Step: 1735, {'train/ssim': 0.7361899784633091, 'train/loss': 0.27769272668021067, 'validation/ssim': 0.7143108859515687, 'validation/loss': 0.29573109744302195, 'validation/num_examples': 3554, 'test/ssim': 0.7314950132862678, 'test/loss': 0.29746227376256634, 'test/num_examples': 3581, 'score': 545.9810972213745, 'total_duration': 785.6000730991364, 'accumulated_submission_time': 545.9810972213745, 'accumulated_eval_time': 237.73684239387512, 'accumulated_logging_time': 1.8607571125030518} +I0418 13:24:40.389561 139601586935552 logging_writer.py:48] [1735] accumulated_eval_time=237.736842, accumulated_logging_time=1.860757, accumulated_submission_time=545.981097, global_step=1735, preemption_count=0, score=545.981097, test/loss=0.297462, test/num_examples=3581, test/ssim=0.731495, total_duration=785.600073, train/loss=0.277693, train/ssim=0.736190, validation/loss=0.295731, validation/num_examples=3554, validation/ssim=0.714311 +I0418 13:24:40.426127 139811861223232 checkpoints.py:356] Saving checkpoint at step: 1735 +I0418 13:24:40.637619 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_1735 +I0418 13:24:40.638168 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_1735. +I0418 13:24:53.912803 139561673946880 logging_writer.py:48] [1800] global_step=1800, grad_norm=0.24099458754062653, loss=0.3218531012535095 +I0418 13:25:17.370750 139561640376064 logging_writer.py:48] [1900] global_step=1900, grad_norm=0.23585671186447144, loss=0.34930768609046936 +I0418 13:25:40.919478 139561673946880 logging_writer.py:48] [2000] global_step=2000, grad_norm=0.1534263789653778, loss=0.2778470814228058 +I0418 13:26:00.795663 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:26:02.201891 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:26:03.550091 139811861223232 spec.py:326] Evaluating on the test split. 
+I0418 13:26:04.900637 139811861223232 submission_runner.py:406] Time since start: 870.12s, Step: 2085, {'train/ssim': 0.736290454864502, 'train/loss': 0.2762721947261265, 'validation/ssim': 0.7143876178249859, 'validation/loss': 0.29404577835933104, 'validation/num_examples': 3554, 'test/ssim': 0.7318026945598296, 'test/loss': 0.2957445968632191, 'test/num_examples': 3581, 'score': 626.1339099407196, 'total_duration': 870.1193282604218, 'accumulated_submission_time': 626.1339099407196, 'accumulated_eval_time': 241.8417820930481, 'accumulated_logging_time': 2.117722988128662} +I0418 13:26:04.908618 139561640376064 logging_writer.py:48] [2085] accumulated_eval_time=241.841782, accumulated_logging_time=2.117723, accumulated_submission_time=626.133910, global_step=2085, preemption_count=0, score=626.133910, test/loss=0.295745, test/num_examples=3581, test/ssim=0.731803, total_duration=870.119328, train/loss=0.276272, train/ssim=0.736290, validation/loss=0.294046, validation/num_examples=3554, validation/ssim=0.714388 +I0418 13:26:04.945039 139811861223232 checkpoints.py:356] Saving checkpoint at step: 2085 +I0418 13:26:05.157962 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_2085 +I0418 13:26:05.158536 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_2085. +I0418 13:26:06.700137 139561673946880 logging_writer.py:48] [2100] global_step=2100, grad_norm=0.10490967333316803, loss=0.3074108362197876 +I0418 13:26:30.318451 139561631983360 logging_writer.py:48] [2200] global_step=2200, grad_norm=0.21701167523860931, loss=0.21758627891540527 +I0418 13:26:53.750888 139561673946880 logging_writer.py:48] [2300] global_step=2300, grad_norm=0.13879521191120148, loss=0.3594713807106018 +I0418 13:27:17.112094 139561631983360 logging_writer.py:48] [2400] global_step=2400, grad_norm=0.38092124462127686, loss=0.3217669427394867 +I0418 13:27:25.257875 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:27:26.659041 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:27:28.008491 139811861223232 spec.py:326] Evaluating on the test split. 
+I0418 13:27:29.363351 139811861223232 submission_runner.py:406] Time since start: 954.58s, Step: 2436, {'train/ssim': 0.7389028412955148, 'train/loss': 0.2739771604537964, 'validation/ssim': 0.7167958445809299, 'validation/loss': 0.29225274613551633, 'validation/num_examples': 3554, 'test/ssim': 0.7340429796582658, 'test/loss': 0.293848228933957, 'test/num_examples': 3581, 'score': 706.2287080287933, 'total_duration': 954.5820419788361, 'accumulated_submission_time': 706.2287080287933, 'accumulated_eval_time': 245.9472210407257, 'accumulated_logging_time': 2.375852584838867} +I0418 13:27:29.371792 139561673946880 logging_writer.py:48] [2436] accumulated_eval_time=245.947221, accumulated_logging_time=2.375853, accumulated_submission_time=706.228708, global_step=2436, preemption_count=0, score=706.228708, test/loss=0.293848, test/num_examples=3581, test/ssim=0.734043, total_duration=954.582042, train/loss=0.273977, train/ssim=0.738903, validation/loss=0.292253, validation/num_examples=3554, validation/ssim=0.716796 +I0418 13:27:29.408613 139811861223232 checkpoints.py:356] Saving checkpoint at step: 2436 +I0418 13:27:29.623686 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_2436 +I0418 13:27:29.624253 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_2436. +I0418 13:27:42.525188 139561631983360 logging_writer.py:48] [2500] global_step=2500, grad_norm=0.10224933922290802, loss=0.20601877570152283 +I0418 13:28:06.205456 139561522943744 logging_writer.py:48] [2600] global_step=2600, grad_norm=0.18922226130962372, loss=0.29663658142089844 +I0418 13:28:29.743257 139561631983360 logging_writer.py:48] [2700] global_step=2700, grad_norm=0.14750812947750092, loss=0.3225230574607849 +I0418 13:28:49.819688 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:28:51.226343 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:28:52.576523 139811861223232 spec.py:326] Evaluating on the test split. 
+I0418 13:28:53.929800 139811861223232 submission_runner.py:406] Time since start: 1039.15s, Step: 2787, {'train/ssim': 0.740485531943185, 'train/loss': 0.27365691321236746, 'validation/ssim': 0.719278536288337, 'validation/loss': 0.2916071542408202, 'validation/num_examples': 3554, 'test/ssim': 0.7363298294165387, 'test/loss': 0.29325788722162105, 'test/num_examples': 3581, 'score': 786.4194860458374, 'total_duration': 1039.1484916210175, 'accumulated_submission_time': 786.4194860458374, 'accumulated_eval_time': 250.05729293823242, 'accumulated_logging_time': 2.636970281600952} +I0418 13:28:53.938604 139561522943744 logging_writer.py:48] [2787] accumulated_eval_time=250.057293, accumulated_logging_time=2.636970, accumulated_submission_time=786.419486, global_step=2787, preemption_count=0, score=786.419486, test/loss=0.293258, test/num_examples=3581, test/ssim=0.736330, total_duration=1039.148492, train/loss=0.273657, train/ssim=0.740486, validation/loss=0.291607, validation/num_examples=3554, validation/ssim=0.719279 +I0418 13:28:53.973526 139811861223232 checkpoints.py:356] Saving checkpoint at step: 2787 +I0418 13:28:54.186844 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_2787 +I0418 13:28:54.187454 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_2787. +I0418 13:28:55.308030 139561631983360 logging_writer.py:48] [2800] global_step=2800, grad_norm=0.1064099594950676, loss=0.3366606533527374 +I0418 13:29:18.703960 139561623590656 logging_writer.py:48] [2900] global_step=2900, grad_norm=0.15394526720046997, loss=0.31380361318588257 +I0418 13:29:41.974687 139561631983360 logging_writer.py:48] [3000] global_step=3000, grad_norm=0.09457140415906906, loss=0.3099451959133148 +I0418 13:30:05.674211 139561623590656 logging_writer.py:48] [3100] global_step=3100, grad_norm=0.1710084080696106, loss=0.21425148844718933 +I0418 13:30:14.377082 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:30:15.783730 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:30:17.136895 139811861223232 spec.py:326] Evaluating on the test split. 
+I0418 13:30:18.489351 139811861223232 submission_runner.py:406] Time since start: 1123.71s, Step: 3137, {'train/ssim': 0.7372136116027832, 'train/loss': 0.276213492665972, 'validation/ssim': 0.7147835735263084, 'validation/loss': 0.29487282705226503, 'validation/num_examples': 3554, 'test/ssim': 0.7319260261405682, 'test/loss': 0.29648540446758936, 'test/num_examples': 3581, 'score': 866.6045970916748, 'total_duration': 1123.708041191101, 'accumulated_submission_time': 866.6045970916748, 'accumulated_eval_time': 254.16952347755432, 'accumulated_logging_time': 2.894801139831543} +I0418 13:30:18.498954 139561631983360 logging_writer.py:48] [3137] accumulated_eval_time=254.169523, accumulated_logging_time=2.894801, accumulated_submission_time=866.604597, global_step=3137, preemption_count=0, score=866.604597, test/loss=0.296485, test/num_examples=3581, test/ssim=0.731926, total_duration=1123.708041, train/loss=0.276213, train/ssim=0.737214, validation/loss=0.294873, validation/num_examples=3554, validation/ssim=0.714784 +I0418 13:30:18.536054 139811861223232 checkpoints.py:356] Saving checkpoint at step: 3137 +I0418 13:30:18.748771 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_3137 +I0418 13:30:18.749326 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_3137. +I0418 13:30:31.629266 139561623590656 logging_writer.py:48] [3200] global_step=3200, grad_norm=0.04792667180299759, loss=0.22304023802280426 +I0418 13:30:55.307512 139561759127296 logging_writer.py:48] [3300] global_step=3300, grad_norm=0.08638576418161392, loss=0.25993800163269043 +I0418 13:31:18.809254 139561623590656 logging_writer.py:48] [3400] global_step=3400, grad_norm=0.12470752745866776, loss=0.2743197977542877 +I0418 13:31:38.865733 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:31:40.273889 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:31:41.623636 139811861223232 spec.py:326] Evaluating on the test split. 
+I0418 13:31:42.977093 139811861223232 submission_runner.py:406] Time since start: 1208.20s, Step: 3488, {'train/ssim': 0.7427416528974261, 'train/loss': 0.2716409649167742, 'validation/ssim': 0.7203257168419387, 'validation/loss': 0.2901462947778911, 'validation/num_examples': 3554, 'test/ssim': 0.7374846738864842, 'test/loss': 0.2917485581999616, 'test/num_examples': 3581, 'score': 946.7163908481598, 'total_duration': 1208.1957819461823, 'accumulated_submission_time': 946.7163908481598, 'accumulated_eval_time': 258.28086018562317, 'accumulated_logging_time': 3.155013084411621} +I0418 13:31:42.986239 139561759127296 logging_writer.py:48] [3488] accumulated_eval_time=258.280860, accumulated_logging_time=3.155013, accumulated_submission_time=946.716391, global_step=3488, preemption_count=0, score=946.716391, test/loss=0.291749, test/num_examples=3581, test/ssim=0.737485, total_duration=1208.195782, train/loss=0.271641, train/ssim=0.742742, validation/loss=0.290146, validation/num_examples=3554, validation/ssim=0.720326 +I0418 13:31:43.024227 139811861223232 checkpoints.py:356] Saving checkpoint at step: 3488 +I0418 13:31:43.235592 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_3488 +I0418 13:31:43.236161 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_3488. +I0418 13:31:44.187038 139561623590656 logging_writer.py:48] [3500] global_step=3500, grad_norm=0.08530890941619873, loss=0.2989664375782013 +I0418 13:32:07.555382 139561657161472 logging_writer.py:48] [3600] global_step=3600, grad_norm=0.07605163753032684, loss=0.22411733865737915 +I0418 13:32:30.977212 139561623590656 logging_writer.py:48] [3700] global_step=3700, grad_norm=0.12153548747301102, loss=0.2602514922618866 +I0418 13:32:54.498524 139561657161472 logging_writer.py:48] [3800] global_step=3800, grad_norm=0.0802716538310051, loss=0.33755627274513245 +I0418 13:33:03.404680 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:33:04.807506 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:33:06.156959 139811861223232 spec.py:326] Evaluating on the test split. 
+I0418 13:33:07.507241 139811861223232 submission_runner.py:406] Time since start: 1292.73s, Step: 3839, {'train/ssim': 0.7396433012826102, 'train/loss': 0.2715693882533482, 'validation/ssim': 0.7179693545125211, 'validation/loss': 0.2897802899242051, 'validation/num_examples': 3554, 'test/ssim': 0.7351546683014522, 'test/loss': 0.29129910355871264, 'test/num_examples': 3581, 'score': 1026.8801860809326, 'total_duration': 1292.725922346115, 'accumulated_submission_time': 1026.8801860809326, 'accumulated_eval_time': 262.3833680152893, 'accumulated_logging_time': 3.4143550395965576} +I0418 13:33:07.515569 139561623590656 logging_writer.py:48] [3839] accumulated_eval_time=262.383368, accumulated_logging_time=3.414355, accumulated_submission_time=1026.880186, global_step=3839, preemption_count=0, score=1026.880186, test/loss=0.291299, test/num_examples=3581, test/ssim=0.735155, total_duration=1292.725922, train/loss=0.271569, train/ssim=0.739643, validation/loss=0.289780, validation/num_examples=3554, validation/ssim=0.717969 +I0418 13:33:07.551471 139811861223232 checkpoints.py:356] Saving checkpoint at step: 3839 +I0418 13:33:07.757831 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_3839 +I0418 13:33:07.758429 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_3839. +I0418 13:33:20.033087 139561657161472 logging_writer.py:48] [3900] global_step=3900, grad_norm=0.3017250597476959, loss=0.30064845085144043 +I0418 13:33:43.710441 139561648768768 logging_writer.py:48] [4000] global_step=4000, grad_norm=0.1491791307926178, loss=0.34225165843963623 +I0418 13:34:07.239346 139561657161472 logging_writer.py:48] [4100] global_step=4100, grad_norm=0.15506602823734283, loss=0.27582505345344543 +I0418 13:34:27.999338 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:34:29.404618 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:34:30.758607 139811861223232 spec.py:326] Evaluating on the test split. 
+I0418 13:34:32.112137 139811861223232 submission_runner.py:406] Time since start: 1377.33s, Step: 4189, {'train/ssim': 0.7408084869384766, 'train/loss': 0.27250548771449495, 'validation/ssim': 0.7184662912827097, 'validation/loss': 0.2907320880192213, 'validation/num_examples': 3554, 'test/ssim': 0.7357353289278483, 'test/loss': 0.292204421447396, 'test/num_examples': 3581, 'score': 1107.116602897644, 'total_duration': 1377.3308203220367, 'accumulated_submission_time': 1107.116602897644, 'accumulated_eval_time': 266.4961130619049, 'accumulated_logging_time': 3.6657447814941406} +I0418 13:34:32.120835 139561648768768 logging_writer.py:48] [4189] accumulated_eval_time=266.496113, accumulated_logging_time=3.665745, accumulated_submission_time=1107.116603, global_step=4189, preemption_count=0, score=1107.116603, test/loss=0.292204, test/num_examples=3581, test/ssim=0.735735, total_duration=1377.330820, train/loss=0.272505, train/ssim=0.740808, validation/loss=0.290732, validation/num_examples=3554, validation/ssim=0.718466 +I0418 13:34:32.157636 139811861223232 checkpoints.py:356] Saving checkpoint at step: 4189 +I0418 13:34:32.370441 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_4189 +I0418 13:34:32.371031 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_4189. +I0418 13:34:33.253151 139561657161472 logging_writer.py:48] [4200] global_step=4200, grad_norm=0.12261800467967987, loss=0.24371099472045898 +I0418 13:34:56.732465 139561631983360 logging_writer.py:48] [4300] global_step=4300, grad_norm=0.23392254114151, loss=0.3791535198688507 +I0418 13:35:20.388804 139561657161472 logging_writer.py:48] [4400] global_step=4400, grad_norm=0.07827229052782059, loss=0.3365706503391266 +I0418 13:35:44.055584 139561631983360 logging_writer.py:48] [4500] global_step=4500, grad_norm=0.06862630695104599, loss=0.2232666015625 +I0418 13:35:52.438558 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:35:53.844915 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:35:55.198246 139811861223232 spec.py:326] Evaluating on the test split. 
+I0418 13:35:56.546679 139811861223232 submission_runner.py:406] Time since start: 1461.77s, Step: 4537, {'train/ssim': 0.7422135898045131, 'train/loss': 0.27194954667772564, 'validation/ssim': 0.7192344343521384, 'validation/loss': 0.2907137122124719, 'validation/num_examples': 3554, 'test/ssim': 0.7364333215887671, 'test/loss': 0.29247723035857653, 'test/num_examples': 3581, 'score': 1187.179518699646, 'total_duration': 1461.7653725147247, 'accumulated_submission_time': 1187.179518699646, 'accumulated_eval_time': 270.6042070388794, 'accumulated_logging_time': 3.9248924255371094} +I0418 13:35:56.555123 139561657161472 logging_writer.py:48] [4537] accumulated_eval_time=270.604207, accumulated_logging_time=3.924892, accumulated_submission_time=1187.179519, global_step=4537, preemption_count=0, score=1187.179519, test/loss=0.292477, test/num_examples=3581, test/ssim=0.736433, total_duration=1461.765373, train/loss=0.271950, train/ssim=0.742214, validation/loss=0.290714, validation/num_examples=3554, validation/ssim=0.719234 +I0418 13:35:56.590492 139811861223232 checkpoints.py:356] Saving checkpoint at step: 4537 +I0418 13:35:56.799838 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_4537 +I0418 13:35:56.800457 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_4537. +I0418 13:36:09.590547 139561631983360 logging_writer.py:48] [4600] global_step=4600, grad_norm=0.05950324982404709, loss=0.290679395198822 +I0418 13:36:32.649566 139561522943744 logging_writer.py:48] [4700] global_step=4700, grad_norm=0.10117150843143463, loss=0.2727523148059845 +I0418 13:36:56.109454 139561631983360 logging_writer.py:48] [4800] global_step=4800, grad_norm=0.12309190630912781, loss=0.24177217483520508 +I0418 13:37:16.914435 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:37:18.321350 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:37:19.670757 139811861223232 spec.py:326] Evaluating on the test split. 
+I0418 13:37:21.022195 139811861223232 submission_runner.py:406] Time since start: 1546.24s, Step: 4891, {'train/ssim': 0.7442799295697894, 'train/loss': 0.2703608615057809, 'validation/ssim': 0.7221035332108188, 'validation/loss': 0.2888360825830051, 'validation/num_examples': 3554, 'test/ssim': 0.7392860376378805, 'test/loss': 0.2903330402645909, 'test/num_examples': 3581, 'score': 1267.2888889312744, 'total_duration': 1546.2408757209778, 'accumulated_submission_time': 1267.2888889312744, 'accumulated_eval_time': 274.71193289756775, 'accumulated_logging_time': 4.178853750228882} +I0418 13:37:21.030767 139561522943744 logging_writer.py:48] [4891] accumulated_eval_time=274.711933, accumulated_logging_time=4.178854, accumulated_submission_time=1267.288889, global_step=4891, preemption_count=0, score=1267.288889, test/loss=0.290333, test/num_examples=3581, test/ssim=0.739286, total_duration=1546.240876, train/loss=0.270361, train/ssim=0.744280, validation/loss=0.288836, validation/num_examples=3554, validation/ssim=0.722104 +I0418 13:37:21.066384 139811861223232 checkpoints.py:356] Saving checkpoint at step: 4891 +I0418 13:37:21.278541 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_4891 +I0418 13:37:21.279094 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_4891. +I0418 13:37:22.013600 139561631983360 logging_writer.py:48] [4900] global_step=4900, grad_norm=0.06771455705165863, loss=0.36116886138916016 +I0418 13:37:45.107243 139561514551040 logging_writer.py:48] [5000] global_step=5000, grad_norm=0.07337518036365509, loss=0.3133208155632019 +I0418 13:38:08.456659 139561631983360 logging_writer.py:48] [5100] global_step=5100, grad_norm=0.092543825507164, loss=0.23170822858810425 +I0418 13:38:31.665441 139561514551040 logging_writer.py:48] [5200] global_step=5200, grad_norm=0.13452361524105072, loss=0.26087498664855957 +I0418 13:38:41.343533 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:38:42.745579 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:38:44.098957 139811861223232 spec.py:326] Evaluating on the test split. 
+I0418 13:38:45.448367 139811861223232 submission_runner.py:406] Time since start: 1630.67s, Step: 5242, {'train/ssim': 0.7431642668587821, 'train/loss': 0.2701894555773054, 'validation/ssim': 0.7207614467193655, 'validation/loss': 0.2887529277640511, 'validation/num_examples': 3554, 'test/ssim': 0.7380836740348367, 'test/loss': 0.29012591956681094, 'test/num_examples': 3581, 'score': 1347.3487265110016, 'total_duration': 1630.667048215866, 'accumulated_submission_time': 1347.3487265110016, 'accumulated_eval_time': 278.81671237945557, 'accumulated_logging_time': 4.436016082763672} +I0418 13:38:45.456939 139561631983360 logging_writer.py:48] [5242] accumulated_eval_time=278.816712, accumulated_logging_time=4.436016, accumulated_submission_time=1347.348727, global_step=5242, preemption_count=0, score=1347.348727, test/loss=0.290126, test/num_examples=3581, test/ssim=0.738084, total_duration=1630.667048, train/loss=0.270189, train/ssim=0.743164, validation/loss=0.288753, validation/num_examples=3554, validation/ssim=0.720761 +I0418 13:38:45.493216 139811861223232 checkpoints.py:356] Saving checkpoint at step: 5242 +I0418 13:38:45.712413 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_5242 +I0418 13:38:45.712977 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_5242. +I0418 13:38:57.652769 139561514551040 logging_writer.py:48] [5300] global_step=5300, grad_norm=0.21728134155273438, loss=0.22996574640274048 +I0418 13:39:21.179472 139561506158336 logging_writer.py:48] [5400] global_step=5400, grad_norm=0.20460711419582367, loss=0.27934032678604126 +I0418 13:39:27.604871 139811861223232 spec.py:298] Evaluating on the training split. +I0418 13:39:29.013020 139811861223232 spec.py:310] Evaluating on the validation split. +I0418 13:39:30.366590 139811861223232 spec.py:326] Evaluating on the test split. +I0418 13:39:31.717215 139811861223232 submission_runner.py:406] Time since start: 1676.94s, Step: 5428, {'train/ssim': 0.7416911125183105, 'train/loss': 0.2704659700393677, 'validation/ssim': 0.719704442837296, 'validation/loss': 0.2888820049262275, 'validation/num_examples': 3554, 'test/ssim': 0.7369207847188285, 'test/loss': 0.2903133372094038, 'test/num_examples': 3581, 'score': 1389.2380073070526, 'total_duration': 1676.9359061717987, 'accumulated_submission_time': 1389.2380073070526, 'accumulated_eval_time': 282.9290335178375, 'accumulated_logging_time': 4.700815200805664} +I0418 13:39:31.726211 139561514551040 logging_writer.py:48] [5428] accumulated_eval_time=282.929034, accumulated_logging_time=4.700815, accumulated_submission_time=1389.238007, global_step=5428, preemption_count=0, score=1389.238007, test/loss=0.290313, test/num_examples=3581, test/ssim=0.736921, total_duration=1676.935906, train/loss=0.270466, train/ssim=0.741691, validation/loss=0.288882, validation/num_examples=3554, validation/ssim=0.719704 +I0418 13:39:31.762403 139811861223232 checkpoints.py:356] Saving checkpoint at step: 5428 +I0418 13:39:31.974781 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_5428 +I0418 13:39:31.975349 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_5428. 
+I0418 13:39:31.982997 139561506158336 logging_writer.py:48] [5428] global_step=5428, preemption_count=0, score=1389.238007
+I0418 13:39:32.013861 139811861223232 checkpoints.py:356] Saving checkpoint at step: 5428
+I0418 13:39:32.309222 139811861223232 checkpoints.py:317] Saved checkpoint at /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_5428
+I0418 13:39:32.309746 139811861223232 checkpoint_utils.py:240] Saved checkpoint to /experiment_runs/timing_v3/timing_adamw/fastmri_jax/trial_1/checkpoint_5428.
+I0418 13:39:33.007552 139811861223232 submission_runner.py:567] Tuning trial 1/1
+I0418 13:39:33.007801 139811861223232 submission_runner.py:568] Hyperparameters: Hyperparameters(learning_rate=0.0019814680146414726, one_minus_beta1=0.22838767981804783, beta2=0.999, warmup_factor=0.05, weight_decay=0.010340635370188849, label_smoothing=0.1, dropout_rate=0.0)
+I0418 13:39:33.012595 139811861223232 submission_runner.py:569] Metrics: {'eval_results': [(1, {'train/ssim': 0.21045025757380895, 'train/loss': 0.8323062488010952, 'validation/ssim': 0.20073387820061903, 'validation/loss': 0.8397238092202448, 'validation/num_examples': 3554, 'test/ssim': 0.22385435256169714, 'test/loss': 0.8394407577535954, 'test/num_examples': 3581, 'score': 64.92480111122131, 'total_duration': 277.60004019737244, 'accumulated_submission_time': 64.92480111122131, 'accumulated_eval_time': 212.67506647109985, 'accumulated_logging_time': 0, 'global_step': 1, 'preemption_count': 0}), (331, {'train/ssim': 0.7070232800074986, 'train/loss': 0.30231707436697824, 'validation/ssim': 0.6850338636923186, 'validation/loss': 0.32094466209948297, 'validation/num_examples': 3554, 'test/ssim': 0.7031452484684795, 'test/loss': 0.32283779763203363, 'test/num_examples': 3581, 'score': 145.01519441604614, 'total_duration': 362.53429913520813, 'accumulated_submission_time': 145.01519441604614, 'accumulated_eval_time': 217.22807550430298, 'accumulated_logging_time': 0.28686976432800293, 'global_step': 331, 'preemption_count': 0}), (560, {'train/ssim': 0.7200777871268136, 'train/loss': 0.290642329624721, 'validation/ssim': 0.6985314604802687, 'validation/loss': 0.3087669977929797, 'validation/num_examples': 3554, 'test/ssim': 0.7159312359283371, 'test/loss': 0.31089522365217465, 'test/num_examples': 3581, 'score': 225.29421615600586, 'total_duration': 447.26284670829773, 'accumulated_submission_time': 225.29421615600586, 'accumulated_eval_time': 221.3283338546753, 'accumulated_logging_time': 0.633378267288208, 'global_step': 560, 'preemption_count': 0}), (787, {'train/ssim': 0.7264369555882045, 'train/loss': 0.2863526003701346, 'validation/ssim': 0.7047054567696609, 'validation/loss': 0.30473246061875703, 'validation/num_examples': 3554, 'test/ssim': 0.7218843538379642, 'test/loss': 0.30675468455476823, 'test/num_examples': 3581, 'score': 305.5884253978729, 'total_duration': 531.9939126968384, 'accumulated_submission_time': 305.5884253978729, 'accumulated_eval_time': 225.4286892414093, 'accumulated_logging_time': 0.9671237468719482, 'global_step': 787, 'preemption_count': 0}), (1034, {'train/ssim': 0.7311224256243024, 'train/loss': 0.281079752104623, 'validation/ssim': 0.7095664243853756, 'validation/loss': 0.298722438350714, 'validation/num_examples': 3554, 'test/ssim': 0.7266145187927604, 'test/loss': 0.30085532394277087, 'test/num_examples': 3581, 'score': 385.6187229156494, 'total_duration': 616.4704086780548, 'accumulated_submission_time': 385.6187229156494, 'accumulated_eval_time': 229.52676105499268, 'accumulated_logging_time': 1.3123102188110352, 'global_step': 1034, 'preemption_count': 0}), (1384, {'train/ssim': 0.7352303096226284, 'train/loss': 0.278595038822719, 'validation/ssim': 0.7136715452658976, 'validation/loss': 0.29653461823385624, 'validation/num_examples': 3554, 'test/ssim': 0.7307323891502024, 'test/loss': 0.2984414610640359, 'test/num_examples': 3581, 'score': 465.86059069633484, 'total_duration': 701.1074573993683, 'accumulated_submission_time': 465.86059069633484, 'accumulated_eval_time': 233.62830877304077, 'accumulated_logging_time': 1.601564884185791, 'global_step': 1384, 'preemption_count': 0}), (1735, {'train/ssim': 0.7361899784633091, 'train/loss': 0.27769272668021067, 'validation/ssim': 0.7143108859515687, 'validation/loss': 0.29573109744302195, 'validation/num_examples': 3554, 'test/ssim': 0.7314950132862678, 'test/loss': 0.29746227376256634, 'test/num_examples': 3581, 'score': 545.9810972213745, 'total_duration': 785.6000730991364, 'accumulated_submission_time': 545.9810972213745, 'accumulated_eval_time': 237.73684239387512, 'accumulated_logging_time': 1.8607571125030518, 'global_step': 1735, 'preemption_count': 0}), (2085, {'train/ssim': 0.736290454864502, 'train/loss': 0.2762721947261265, 'validation/ssim': 0.7143876178249859, 'validation/loss': 0.29404577835933104, 'validation/num_examples': 3554, 'test/ssim': 0.7318026945598296, 'test/loss': 0.2957445968632191, 'test/num_examples': 3581, 'score': 626.1339099407196, 'total_duration': 870.1193282604218, 'accumulated_submission_time': 626.1339099407196, 'accumulated_eval_time': 241.8417820930481, 'accumulated_logging_time': 2.117722988128662, 'global_step': 2085, 'preemption_count': 0}), (2436, {'train/ssim': 0.7389028412955148, 'train/loss': 0.2739771604537964, 'validation/ssim': 0.7167958445809299, 'validation/loss': 0.29225274613551633, 'validation/num_examples': 3554, 'test/ssim': 0.7340429796582658, 'test/loss': 0.293848228933957, 'test/num_examples': 3581, 'score': 706.2287080287933, 'total_duration': 954.5820419788361, 'accumulated_submission_time': 706.2287080287933, 'accumulated_eval_time': 245.9472210407257, 'accumulated_logging_time': 2.375852584838867, 'global_step': 2436, 'preemption_count': 0}), (2787, {'train/ssim': 0.740485531943185, 'train/loss': 0.27365691321236746, 'validation/ssim': 0.719278536288337, 'validation/loss': 0.2916071542408202, 'validation/num_examples': 3554, 'test/ssim': 0.7363298294165387, 'test/loss': 0.29325788722162105, 'test/num_examples': 3581, 'score': 786.4194860458374, 'total_duration': 1039.1484916210175, 'accumulated_submission_time': 786.4194860458374, 'accumulated_eval_time': 250.05729293823242, 'accumulated_logging_time': 2.636970281600952, 'global_step': 2787, 'preemption_count': 0}), (3137, {'train/ssim': 0.7372136116027832, 'train/loss': 0.276213492665972, 'validation/ssim': 0.7147835735263084, 'validation/loss': 0.29487282705226503, 'validation/num_examples': 3554, 'test/ssim': 0.7319260261405682, 'test/loss': 0.29648540446758936, 'test/num_examples': 3581, 'score': 866.6045970916748, 'total_duration': 1123.708041191101, 'accumulated_submission_time': 866.6045970916748, 'accumulated_eval_time': 254.16952347755432, 'accumulated_logging_time': 2.894801139831543, 'global_step': 3137, 'preemption_count': 0}), (3488, {'train/ssim': 0.7427416528974261, 'train/loss': 0.2716409649167742, 'validation/ssim': 0.7203257168419387, 'validation/loss': 0.2901462947778911, 'validation/num_examples': 3554, 'test/ssim': 0.7374846738864842, 'test/loss': 0.2917485581999616, 'test/num_examples': 3581, 'score': 946.7163908481598, 'total_duration': 1208.1957819461823, 'accumulated_submission_time': 946.7163908481598, 'accumulated_eval_time': 258.28086018562317, 'accumulated_logging_time': 3.155013084411621, 'global_step': 3488, 'preemption_count': 0}), (3839, {'train/ssim': 0.7396433012826102, 'train/loss': 0.2715693882533482, 'validation/ssim': 0.7179693545125211, 'validation/loss': 0.2897802899242051, 'validation/num_examples': 3554, 'test/ssim': 0.7351546683014522, 'test/loss': 0.29129910355871264, 'test/num_examples': 3581, 'score': 1026.8801860809326, 'total_duration': 1292.725922346115, 'accumulated_submission_time': 1026.8801860809326, 'accumulated_eval_time': 262.3833680152893, 'accumulated_logging_time': 3.4143550395965576, 'global_step': 3839, 'preemption_count': 0}), (4189, {'train/ssim': 0.7408084869384766, 'train/loss': 0.27250548771449495, 'validation/ssim': 0.7184662912827097, 'validation/loss': 0.2907320880192213, 'validation/num_examples': 3554, 'test/ssim': 0.7357353289278483, 'test/loss': 0.292204421447396, 'test/num_examples': 3581, 'score': 1107.116602897644, 'total_duration': 1377.3308203220367, 'accumulated_submission_time': 1107.116602897644, 'accumulated_eval_time': 266.4961130619049, 'accumulated_logging_time': 3.6657447814941406, 'global_step': 4189, 'preemption_count': 0}), (4537, {'train/ssim': 0.7422135898045131, 'train/loss': 0.27194954667772564, 'validation/ssim': 0.7192344343521384, 'validation/loss': 0.2907137122124719, 'validation/num_examples': 3554, 'test/ssim': 0.7364333215887671, 'test/loss': 0.29247723035857653, 'test/num_examples': 3581, 'score': 1187.179518699646, 'total_duration': 1461.7653725147247, 'accumulated_submission_time': 1187.179518699646, 'accumulated_eval_time': 270.6042070388794, 'accumulated_logging_time': 3.9248924255371094, 'global_step': 4537, 'preemption_count': 0}), (4891, {'train/ssim': 0.7442799295697894, 'train/loss': 0.2703608615057809, 'validation/ssim': 0.7221035332108188, 'validation/loss': 0.2888360825830051, 'validation/num_examples': 3554, 'test/ssim': 0.7392860376378805, 'test/loss': 0.2903330402645909, 'test/num_examples': 3581, 'score': 1267.2888889312744, 'total_duration': 1546.2408757209778, 'accumulated_submission_time': 1267.2888889312744, 'accumulated_eval_time': 274.71193289756775, 'accumulated_logging_time': 4.178853750228882, 'global_step': 4891, 'preemption_count': 0}), (5242, {'train/ssim': 0.7431642668587821, 'train/loss': 0.2701894555773054, 'validation/ssim': 0.7207614467193655, 'validation/loss': 0.2887529277640511, 'validation/num_examples': 3554, 'test/ssim': 0.7380836740348367, 'test/loss': 0.29012591956681094, 'test/num_examples': 3581, 'score': 1347.3487265110016, 'total_duration': 1630.667048215866, 'accumulated_submission_time': 1347.3487265110016, 'accumulated_eval_time': 278.81671237945557, 'accumulated_logging_time': 4.436016082763672, 'global_step': 5242, 'preemption_count': 0}), (5428, {'train/ssim': 0.7416911125183105, 'train/loss': 0.2704659700393677, 'validation/ssim': 0.719704442837296, 'validation/loss': 0.2888820049262275, 'validation/num_examples': 3554, 'test/ssim': 0.7369207847188285, 'test/loss': 0.2903133372094038, 'test/num_examples': 3581, 'score': 1389.2380073070526, 'total_duration': 1676.9359061717987, 'accumulated_submission_time': 1389.2380073070526, 'accumulated_eval_time': 282.9290335178375, 'accumulated_logging_time': 4.700815200805664, 'global_step': 5428, 'preemption_count': 0})], 'global_step': 5428}
+I0418 13:39:33.012728 139811861223232 submission_runner.py:570] Timing: 1389.2380073070526
+I0418 13:39:33.012772 139811861223232 submission_runner.py:571] ====================
+I0418 13:39:33.012883 139811861223232 submission_runner.py:631] Final fastmri score: 1389.2380073070526
diff --git a/scoring/test_scoring_utils.py b/scoring/test_scoring_utils.py
new file mode 100644
index 000000000..597525386
--- /dev/null
+++ b/scoring/test_scoring_utils.py
@@ -0,0 +1,27 @@
+from absl.testing import absltest
+import scoring_utils
+
+TEST_LOGFILE = 'test_data/trial_0/adamw_fastmri_jax_04-18-2023-13-10-58.log'
+NUM_EVALS = 18
+
+
+class Test(absltest.TestCase):
+
+  def test_get_trials_dict(self):
+    trials_dict = scoring_utils.get_trials_dict(TEST_LOGFILE)
+    self.assertEqual(len(trials_dict['1']['global_step']), NUM_EVALS)
+
+  def test_get_trials_df_dict(self):
+    trials_dict = scoring_utils.get_trials_df_dict(TEST_LOGFILE)
+    for trial in trials_dict:
+      df = trials_dict[trial]
+      self.assertEqual(len(df.index), NUM_EVALS)
+
+  def test_get_trials_df(self):
+    df = scoring_utils.get_trials_df(TEST_LOGFILE)
+    for column in df.columns:
+      self.assertEqual(len(df.at['1', column]), NUM_EVALS)
+
+
+if __name__ == '__main__':
+  absltest.main()
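
Reviewer note: a minimal end-to-end sketch of how the pieces in this diff are intended to fit together, parsing the fixture log above with scoring_utils and scoring it with scoring.compute_performance_profiles. The workload label, the submission tag, and the 0.72 SSIM target below are illustrative assumptions, not values taken from this diff; the tests only guarantee that get_trials_df returns one row per trial with per-eval arrays such as 'global_step'.

# Sketch only; workload label, tag format, and target are assumptions.
import scoring
import scoring_utils

log = 'test_data/trial_0/adamw_fastmri_jax_04-18-2023-13-10-58.log'
df = scoring_utils.get_trials_df(log)

# The log emits 'validation/...' columns, but MINIMIZE_REGISTRY in scoring.py
# is keyed on 'valid/...' (see generate_eval_cols), so check_if_minimized()
# would not match the longer name; rename before scoring (assumed intent).
df = df.rename(columns={'validation/ssim': 'valid/ssim'})
df['workload'] = 'fastmri_jax'  # assumed workload identifier for this log

# get_times_for_submission splits the tag on '.', so use a hypothetical
# '<study>.<submission>' form.
results = {'timing_v3.adamw': df}
workload_metadata = {
    'fastmri_jax': {'metric': 'valid/ssim', 'target': 0.72},
}

# 'score' is the accumulated submission time recorded at each eval, which
# makes it a natural time column; 'global_step' (the default) also works.
perf_df = scoring.compute_performance_profiles(
    results, workload_metadata, time_col='score')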