diff --git a/scripts/train.py b/scripts/train.py index d936083b..f81f7553 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -139,6 +139,9 @@ def process_batch(batch): del batch["dataset_name"] return batch + # copy the original config before we modify it + model_config = FLAGS.config.to_dict() + # load datasets if "oxe_kwargs" in FLAGS.config.dataset_kwargs: # create dataset_kwargs_list from oxe_kwargs @@ -180,7 +183,7 @@ def process_batch(batch): rng = jax.random.PRNGKey(FLAGS.config.seed) rng, init_rng = jax.random.split(rng) model = OctoModel.from_config( - FLAGS.config.to_dict(), + model_config, example_batch, text_processor, verbose=True,