From dd6b835d0a3447903590f965b623b129814b84e6 Mon Sep 17 00:00:00 2001
From: Jason Gross
Date: Wed, 21 Aug 2024 03:38:08 -0700
Subject: [PATCH] Revert "Adjust multifun"

This reverts commit c7a2f15f010d22796ee75426a50982a6ba451324.

Work around https://github.com/lebrice/SimpleParsing/issues/322
---
 gbmi/exp_multifun/train.py | 199 ++++++++++++++++++++++++++++++++++---
 1 file changed, 186 insertions(+), 13 deletions(-)

diff --git a/gbmi/exp_multifun/train.py b/gbmi/exp_multifun/train.py
index 0516dba6..4766ff58 100644
--- a/gbmi/exp_multifun/train.py
+++ b/gbmi/exp_multifun/train.py
@@ -5,10 +5,9 @@
 import sys
 from dataclasses import dataclass, field
 from functools import cache
-from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Literal, Optional, Sequence, Tuple, Union
 
 import numpy as np
-import simple_parsing
 import torch
 import torch.nn.functional as F
 from jaxtyping import Bool, Float, Integer
@@ -595,26 +594,200 @@ def test_dataloader(self):
         return DataLoader(self.data_test, batch_size=self.config.batch_size)
 
 
-def main(argv: List[str] = sys.argv):
-    parser = simple_parsing.ArgumentParser(
+def config_of_argv(argv=sys.argv) -> tuple[Config[Multifun], dict]:
+    parser = argparse.ArgumentParser(
         description="Train a model with configurable attention rate."
     )
-    parser.add_arguments(
-        Multifun, dest="experiment_config", default=MULTIFUN_OF_2_CONFIG.experiment
-    )
     add_force_argument(parser)
     add_no_save_argument(parser)
-    Config.add_arguments(parser, default=MULTIFUN_OF_2_CONFIG)
-
+    # add --K N argument accepting 2 and 10
+    parser.add_argument(
+        "--K",
+        metavar="K",
+        type=int,
+        default=10,
+        help="The length of the list to take the reduction of.",
+    )
+    parser.add_argument(
+        "--func",
+        metavar="FUNC",
+        type=str,
+        nargs="+",
+        default=["max", "min"],
+        help="The functions to apply to the list.",
+    )
+    parser.add_argument(
+        "--force-adjacent-gap",
+        metavar="K",
+        type=str,
+        action="append",
+        help="For --K 2, include all sequences (n, n±K) in training set. Accepts an int or a comma-separated list.",
+    )
+    parser.add_argument(
+        "--training-ratio",
+        type=float,
+        default=0.7,
+        help="For --K 2, the fraction of sequences to include in training.",
+    )
+    parser.add_argument(
+        "--use-log1p",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Use a more accurate implementation of log_softmax.",
+    )
+    parser.add_argument(
+        "--use-end-of-sequence",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Use an end-of-sequence token so the query-side attention vector is fixed.",
+    )
+    parser.add_argument("--weight-decay", type=float, default=None, help="Weight decay")
+    parser.add_argument(
+        "--optimizer",
+        choices=["Adam", "AdamW", "SGD"],
+        default="Adam",
+        help="The optimizer to use.",
+    )
+    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
+    parser.add_argument(
+        "--betas",
+        type=float,
+        nargs=2,
+        default=(0.9, 0.999),
+        help="coefficients used for computing running averages of gradient and its square",
+    )
+    parser.add_argument(
+        "--summary-slug-extra", type=str, default="", help="Extra model description"
+    )
+    parser.add_argument(
+        "--pick-max-first",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Pick the maximum value first, then fill in the rest of the sequence. Only meaningful for --K N > 2.",
+    )
+    parser.add_argument(
+        "--use-kaiming-init",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Use torch.nn.init.kaiming_uniform_, rather than HookedTransformer's init.",
+    )
+    parser.add_argument(
+        "--log-matrix-interp",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Log matrices every train step",
+    )
+    parser.add_argument(
+        "--checkpoint-matrix-interp",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Log matrices for checkpointing",
+    )
+    parser.add_argument(
+        "--log-final-matrix-interp",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Log matrices after training",
+    )
+    HOOKED_TRANSFORMER_CONFIG_ARGS = set(
+        (
+            "normalization_type",
+            "d_model",
+            "d_head",
+            "n_layers",
+            "n_heads",
+            "d_vocab",
+            "dtype",
+            "eps",
+        )
+    )
+    Config.add_arguments(parser)
+    add_HookedTransformerConfig_arguments(parser, HOOKED_TRANSFORMER_CONFIG_ARGS)
     args = parser.parse_args(argv[1:])
 
-    config = Config(args.experiment_config)
-    config = config.update_from_args(args)
-    print("Model config:", MultifunTrainingWrapper.build_model(config).cfg)
+    config = set_params(
+        (MULTIFUN_OF_2_CONFIG if args.K <= 2 else MULTIFUN_OF_10_SINGLE_CONFIG),
+        {
+            ("experiment", "seq_len"): args.K,
+            ("experiment", "funcs"): tuple(args.func),
+            ("experiment", "use_end_of_sequence"): args.use_end_of_sequence,
+            ("experiment", "use_log1p"): args.use_log1p,
+            ("experiment", "optimizer"): args.optimizer,
+            ("experiment", "summary_slug_extra"): args.summary_slug_extra,
+            ("experiment", "train_dataset_cfg", "pick_max_first"): args.pick_max_first,
+            ("experiment", "logging_options"): ModelMatrixLoggingOptions.all(),
+            ("experiment", "log_matrix_on_run_batch_prefixes"): set()
+            | ({"test_"} if args.log_final_matrix_interp else set())
+            | ({"periodic_test_"} if args.checkpoint_matrix_interp else set())
+            | ({""} if args.log_matrix_interp else set()),
+        },
+    ).update_from_args(args)
+    config.experiment = MultifunTrainingWrapper.update_config_from_model_config(
+        config.experiment,
+        update_HookedTransformerConfig_from_args(
+            config,
+            MultifunTrainingWrapper.build_model_config(config),
+            args,
+            HOOKED_TRANSFORMER_CONFIG_ARGS,
+        ),
+    )
+    config.experiment.__post_init__()  # for seq_len, d_vocab
+    if args.weight_decay is not None:
+        config.experiment.optimizer_kwargs["weight_decay"] = args.weight_decay
+    config.experiment.optimizer_kwargs.update(
+        {"lr": args.lr, "betas": tuple(args.betas)}
+    )
+    if args.K <= 2:
+        if args.force_adjacent_gap:
+            force_adjacent = tuple(
+                sorted(
+                    set(
+                        int(k.strip())
+                        for s in args.force_adjacent_gap
+                        for k in s.split(",")
+                    )
+                )
+            )
+            config = set_params(
+                config,
+                {
+                    (
+                        "experiment",
+                        "train_dataset_cfg",
+                        "force_adjacent",
+                    ): force_adjacent,
+                    (
+                        "experiment",
+                        "test_dataset_cfg",
+                        "force_adjacent",
+                    ): force_adjacent,
+                },
+            )
+        config = set_params(
+            config,
+            {
+                (
+                    "experiment",
+                    "train_dataset_cfg",
+                    "training_ratio",
+                ): args.training_ratio,
+                (
+                    "experiment",
+                    "test_dataset_cfg",
+                    "training_ratio",
+                ): args.training_ratio,
+            },
+        )
+    return config, dict(force=args.force, save_to=args.save_to)
+
+
+def main(argv=sys.argv):
+    config, kwargs = config_of_argv(argv)
     print("Training model:", config)
-    train_or_load_model(config, force=args.force, save_to=args.save_to)
+    return train_or_load_model(config, **kwargs)
 
 
+# %%
 if __name__ == "__main__":
     main()
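
A minimal sketch of exercising the restored argparse entry point programmatically (the module path is taken from the diff; the flag values below are illustrative assumptions and presume the gbmi package is importable):

    from gbmi.exp_multifun.train import config_of_argv

    # argv[0] is a dummy program name; config_of_argv only parses argv[1:].
    config, kwargs = config_of_argv(
        ["train.py", "--K", "2", "--force-adjacent-gap", "0,1", "--training-ratio", "0.7"]
    )
    print(config)  # the resolved Config[Multifun]
    print(kwargs)  # {"force": ..., "save_to": ...}, forwarded to train_or_load_model by main()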