From e42b6f700d0268b2316827ecf5d65db2c0064f20 Mon Sep 17 00:00:00 2001 From: sychen52 Date: Tue, 19 Nov 2024 17:22:31 +0000 Subject: [PATCH] Add llama 3 tokenizer --- .../fuji-1B-v3-flash-single-host.txt | 6 +- .../fuji-1B-v3-flash-single-host_init.txt | 2 +- .../fuji-1B-v3-flash-tiktoken-single-host.txt | 287 ++++++++++++++ ...-1B-v3-flash-tiktoken-single-host_init.txt | 9 + ...flash-tiktoken-single-host_regularizer.txt | 10 + .../fuji-1B-v3-flash-tiktoken.txt | 287 ++++++++++++++ .../fuji-1B-v3-flash-tiktoken_init.txt | 9 + .../fuji-1B-v3-flash-tiktoken_regularizer.txt | 10 + .../fuji-1B-v3-flash.txt | 6 +- .../fuji-1B-v3-flash_init.txt | 2 +- .../fuji-1B-v3-single-host.txt | 6 +- .../fuji-1B-v3-single-host_init.txt | 2 +- .../fuji-1B-v3-tiktoken-single-host.txt | 252 +++++++++++++ .../fuji-1B-v3-tiktoken-single-host_init.txt | 9 + ...1B-v3-tiktoken-single-host_regularizer.txt | 10 + .../fuji-1B-v3-tiktoken.txt | 252 +++++++++++++ .../fuji-1B-v3-tiktoken_init.txt | 9 + .../fuji-1B-v3-tiktoken_regularizer.txt | 10 + .../fuji-1B-v3.txt | 6 +- .../fuji-1B-v3_init.txt | 2 +- .../fuji-3B-v3-flash-single-host.txt | 6 +- .../fuji-3B-v3-flash-single-host_init.txt | 2 +- .../fuji-3B-v3-flash-tiktoken-single-host.txt | 287 ++++++++++++++ ...-3B-v3-flash-tiktoken-single-host_init.txt | 9 + ...flash-tiktoken-single-host_regularizer.txt | 10 + .../fuji-3B-v3-flash-tiktoken.txt | 287 ++++++++++++++ .../fuji-3B-v3-flash-tiktoken_init.txt | 9 + .../fuji-3B-v3-flash-tiktoken_regularizer.txt | 10 + .../fuji-3B-v3-flash.txt | 6 +- .../fuji-3B-v3-flash_init.txt | 2 +- .../fuji-3B-v3-single-host.txt | 6 +- .../fuji-3B-v3-single-host_init.txt | 2 +- .../fuji-3B-v3-tiktoken-single-host.txt | 252 +++++++++++++ .../fuji-3B-v3-tiktoken-single-host_init.txt | 9 + ...3B-v3-tiktoken-single-host_regularizer.txt | 10 + .../fuji-3B-v3-tiktoken.txt | 252 +++++++++++++ .../fuji-3B-v3-tiktoken_init.txt | 9 + .../fuji-3B-v3-tiktoken_regularizer.txt | 10 + .../fuji-3B-v3.txt | 6 +- .../fuji-3B-v3_init.txt | 2 +- .../fuji-70B-v1-flash-single-host.txt | 311 +++++++++++++++ .../fuji-70B-v1-flash-single-host_init.txt | 10 + ...i-70B-v1-flash-single-host_regularizer.txt | 11 + .../fuji-70B-v1-single-host.txt | 276 ++++++++++++++ .../fuji-70B-v1-single-host_init.txt | 10 + .../fuji-70B-v1-single-host_regularizer.txt | 11 + .../fuji-70B-v2-flash-single-host.txt | 312 +++++++++++++++ .../fuji-70B-v2-flash-single-host_init.txt | 10 + ...i-70B-v2-flash-single-host_regularizer.txt | 11 + .../fuji-70B-v2-single-host.txt | 277 ++++++++++++++ .../fuji-70B-v2-single-host_init.txt | 10 + .../fuji-70B-v2-single-host_regularizer.txt | 11 + .../fuji-70B-v3-flash-single-host.txt | 312 +++++++++++++++ .../fuji-70B-v3-flash-single-host_init.txt | 10 + ...i-70B-v3-flash-single-host_regularizer.txt | 11 + .../fuji-70B-v3-flash-tiktoken.txt | 313 +++++++++++++++ .../fuji-70B-v3-flash-tiktoken_init.txt | 10 + ...fuji-70B-v3-flash-tiktoken_regularizer.txt | 11 + .../fuji-70B-v3-flash.txt | 6 +- .../fuji-70B-v3-flash_init.txt | 4 +- .../fuji-70B-v3-single-host.txt | 277 ++++++++++++++ .../fuji-70B-v3-single-host_init.txt | 10 + .../fuji-70B-v3-single-host_regularizer.txt | 11 + .../fuji-70B-v3-tiktoken.txt | 278 ++++++++++++++ .../fuji-70B-v3-tiktoken_init.txt | 10 + .../fuji-70B-v3-tiktoken_regularizer.txt | 11 + .../fuji-70B-v3.txt | 6 +- .../fuji-70B-v3_init.txt | 4 +- .../fuji-8B-v3-flash-single-host.txt | 6 +- .../fuji-8B-v3-flash-single-host_init.txt | 4 +- .../fuji-8B-v3-flash-tiktoken-single-host.txt | 356 ++++++++++++++++++ 
...-8B-v3-flash-tiktoken-single-host_init.txt | 10 + ...flash-tiktoken-single-host_regularizer.txt | 11 + .../fuji-8B-v3-flash-tiktoken.txt | 356 ++++++++++++++++++ .../fuji-8B-v3-flash-tiktoken_init.txt | 10 + .../fuji-8B-v3-flash-tiktoken_regularizer.txt | 11 + .../fuji-8B-v3-flash.txt | 6 +- .../fuji-8B-v3-flash_init.txt | 4 +- .../fuji-8B-v3-single-host.txt | 6 +- .../fuji-8B-v3-single-host_init.txt | 4 +- .../fuji-8B-v3-tiktoken-single-host.txt | 321 ++++++++++++++++ .../fuji-8B-v3-tiktoken-single-host_init.txt | 10 + ...8B-v3-tiktoken-single-host_regularizer.txt | 11 + .../fuji-8B-v3-tiktoken.txt | 321 ++++++++++++++++ .../fuji-8B-v3-tiktoken_init.txt | 10 + .../fuji-8B-v3-tiktoken_regularizer.txt | 11 + .../fuji-8B-v3.txt | 6 +- .../fuji-8B-v3_init.txt | 4 +- .../fuji-golden-run-test-v3.txt | 4 +- .../fuji-test-v3-flash-tiktoken.txt | 289 ++++++++++++++ .../fuji-test-v3-flash-tiktoken_init.txt | 9 + ...uji-test-v3-flash-tiktoken_regularizer.txt | 10 + .../fuji-test-v3-flash.txt | 4 +- .../fuji-test-v3-tiktoken.txt | 254 +++++++++++++ .../fuji-test-v3-tiktoken_init.txt | 9 + .../fuji-test-v3-tiktoken_regularizer.txt | 10 + .../fuji-test-v3.txt | 4 +- .../Llama-3.1-70B.json | 34 -- .../Llama-3.1-8B.json | 34 -- .../Llama-3.2-1B.json | 35 -- .../Llama-3.2-3B.json | 35 -- axlearn/experiments/text/gpt/c4_trainer.py | 33 +- axlearn/experiments/text/gpt/common.py | 8 + axlearn/experiments/text/gpt/fuji.py | 35 +- .../text/gpt/param_converter_test.py | 110 ++++-- .../text/gpt/vocabulary_fuji_v3.py | 200 ++++++++++ .../text/gpt/vocabulary_fuji_v3_test.py | 200 ++++++++++ 107 files changed, 7445 insertions(+), 266 deletions(-) create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host_init.txt create mode 100644 
axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_regularizer.txt create mode 100644 
axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_regularizer.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.1-70B.json delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.1-8B.json delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.2-1B.json delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.2-3B.json create mode 100644 axlearn/experiments/text/gpt/vocabulary_fuji_v3.py create mode 100644 
axlearn/experiments/text/gpt/vocabulary_fuji_v3_test.py diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt index 5f32cbbca..3797434cb 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt @@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.logits_partition_spec[0][0]: 'data' model.decoder.logits_partition_spec[0][1]: 'expert' @@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' @@ -267,7 +267,7 @@ model.decoder.transformer.num_layers: 16 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_init.txt index 5c4658cf7..723be0837 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host.txt 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host.txt new file mode 100644 index 000000000..65eaee77e --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host.txt @@ -0,0 +1,287 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 16 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 16 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' 
+evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 16 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' 
+mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 2048 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' 
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 32 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' 
+model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 16 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 
+summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host_init.txt new file mode 100644 index 000000000..5c4658cf7 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken.txt new file mode 100644 index 000000000..3a983d307 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken.txt @@ -0,0 +1,287 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 
+checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 
'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 512 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 2048 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 
'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' 
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 32 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' 
+model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 16 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken_init.txt new file 
mode 100644 index 000000000..5c4658cf7 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt index 52c775e95..01a8c6f0e 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt @@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. 
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.logits_partition_spec[0][0]: 'data' model.decoder.logits_partition_spec[0][1]: 'expert' @@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' @@ -267,7 +267,7 @@ model.decoder.transformer.num_layers: 16 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_init.txt index 5c4658cf7..723be0837 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt index 5667c0030..484767596 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt @@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. 
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.logits_partition_spec[0][0]: 'data' model.decoder.logits_partition_spec[0][1]: 'expert' @@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' @@ -232,7 +232,7 @@ model.decoder.transformer.num_layers: 16 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host_init.txt index 5c4658cf7..723be0837 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt new file mode 100644 index 000000000..5a19cb697 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt @@ -0,0 +1,252 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 
+checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 16 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 16 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' 
+evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 16 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 2048 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' 
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' 
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' 
+model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 32 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 16 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_init.txt new file mode 100644 index 000000000..5c4658cf7 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt new file mode 100644 index 000000000..f910bb221 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt @@ -0,0 +1,252 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.pad_example_fn: 
'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 512 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 
+input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 2048 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 
'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 
'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 32 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' 
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 16 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_init.txt new file mode 100644 index 000000000..5c4658cf7 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 
normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt index 8f45a9e06..4a10c03dc 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt @@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. 
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.logits_partition_spec[0][0]: 'data' model.decoder.logits_partition_spec[0][1]: 'expert' @@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' @@ -232,7 +232,7 @@ model.decoder.transformer.num_layers: 16 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3_init.txt index 5c4658cf7..723be0837 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt index 7610c22be..9335b37fa 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt @@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. 
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.logits_partition_spec[0][0]: 'data' model.decoder.logits_partition_spec[0][1]: 'expert' @@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' @@ -267,7 +267,7 @@ model.decoder.transformer.num_layers: 28 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host_init.txt index b16c157f2..a42dd020d 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host.txt new file mode 100644 index 000000000..efd458b5b --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host.txt @@ -0,0 +1,287 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' 
+checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 16 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 16 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 
'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 16 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 3072 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' 
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' 
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 
500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 24 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' 
+model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 28 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host_init.txt new file mode 100644 index 000000000..b16c157f2 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host_init.txt @@ -0,0 +1,9 @@ 
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken.txt new file mode 100644 index 000000000..32a4e5414 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken.txt @@ -0,0 +1,287 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 
+evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 
'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 512 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 3072 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' 
+model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' 
+model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 24 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 
'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 28 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken_init.txt new file mode 100644 index 000000000..b16c157f2 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), 
axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt index 5c84380f7..7a0912769 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt @@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. 
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.logits_partition_spec[0][0]: 'data' model.decoder.logits_partition_spec[0][1]: 'expert' @@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' @@ -267,7 +267,7 @@ model.decoder.transformer.num_layers: 28 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash_init.txt index b16c157f2..a42dd020d 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt index 0184ae3fd..d22009529 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt @@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. 
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.logits_partition_spec[0][0]: 'data' model.decoder.logits_partition_spec[0][1]: 'expert' @@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' @@ -232,7 +232,7 @@ model.decoder.transformer.num_layers: 28 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host_init.txt index b16c157f2..a42dd020d 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt new file mode 100644 index 000000000..1cce0a2b8 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt @@ -0,0 +1,252 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 
+checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 16 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 16 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' 
+evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 16 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 3072 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' 
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' 
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' 
+model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 24 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 28 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_init.txt new file mode 100644 index 000000000..b16c157f2 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt new file mode 100644 index 000000000..c4b52499d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt @@ -0,0 +1,252 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.pad_example_fn: 
'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 512 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 
+input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 3072 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 
'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 
'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 24 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' 
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 28 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_init.txt new file mode 100644 index 000000000..b16c157f2 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 
normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt index 0f7828bf0..c1a4b0b4e 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt @@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. 
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.logits_partition_spec[0][0]: 'data' model.decoder.logits_partition_spec[0][1]: 'expert' @@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' @@ -232,7 +232,7 @@ model.decoder.transformer.num_layers: 28 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3_init.txt index b16c157f2..a42dd020d 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host.txt new file mode 100644 index 000000000..4f3e7862d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host.txt @@ -0,0 +1,311 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 367001 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 
+checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 367001 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 32 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 2048 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 367001 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 32 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 2048 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' 
+evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 32 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 2048 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 367001 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 367001 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v5litepod-256-4' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True 
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[1][1][0]: 1 +mesh_rules[1][1][1]: -1 +mesh_rules[1][1][2]: 1 +mesh_rules[1][1][3]: 128 +mesh_rules[1][1][4]: 1 +mesh_rules[1][1][5]: 1 +mesh_shape[0]: 1 +mesh_shape[1]: 1 +mesh_shape[2]: 1 +mesh_shape[3]: -1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.dim: 8192 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' 
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 64 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' 
+model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32768 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host_init.txt new file mode 100644 index 000000000..ab71a133e --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host.txt new file mode 100644 index 000000000..074855a49 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host.txt @@ -0,0 +1,276 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 367001 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 367001 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 32 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 2048 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 367001 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' 
+evalers['validation'].input.batcher.global_batch_size: 32 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 2048 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 32 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 2048 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' 
+learner.optimizer.args[1].update_schedule.max_step: 367001 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 367001 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v5litepod-256-4' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[1][1][0]: 1 +mesh_rules[1][1][1]: -1 +mesh_rules[1][1][2]: 1 +mesh_rules[1][1][3]: 128 +mesh_rules[1][1][4]: 1 +mesh_rules[1][1][5]: 1 +mesh_shape[0]: 1 +mesh_shape[1]: 1 +mesh_shape[2]: 1 +mesh_shape[3]: -1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.dim: 8192 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 
'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False 
+model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'MultiheadAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'MultiheadAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'MultiheadAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'MultiheadAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'MultiheadAttention.o_proj' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 64 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' 
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32768 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host_init.txt new file mode 100644 index 000000000..ab71a133e --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 
8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host.txt new file mode 100644 index 000000000..857d879fe --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host.txt @@ -0,0 +1,312 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 524288 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 524288 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 16 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 4096 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 
+evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 524288 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 16 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 4096 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 16 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 4096 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 
128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 524288 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 524288 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v5litepod-256-4' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[1][1][0]: 1 +mesh_rules[1][1][1]: -1 +mesh_rules[1][1][2]: 1 +mesh_rules[1][1][3]: 128 +mesh_rules[1][1][4]: 1 +mesh_rules[1][1][5]: 1 +mesh_shape[0]: 1 +mesh_shape[1]: 1 +mesh_shape[2]: 1 +mesh_shape[3]: -1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.dim: 8192 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' 
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' 
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 64 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' 
+model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32768 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host_init.txt new 
file mode 100644 index 000000000..2f13215e5 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host.txt new file mode 100644 index 000000000..b410021d8 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host.txt @@ -0,0 +1,277 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 524288 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 
'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 524288 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 16 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 4096 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 524288 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 16 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 4096 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 
'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 16 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 4096 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 524288 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 524288 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v5litepod-256-4' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 
'axlearn.common.utils.offload_dots_saveable' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[1][1][0]: 1 +mesh_rules[1][1][1]: -1 +mesh_rules[1][1][2]: 1 +mesh_rules[1][1][3]: 128 +mesh_rules[1][1][4]: 1 +mesh_rules[1][1][5]: 1 +mesh_shape[0]: 1 +mesh_shape[1]: 1 +mesh_shape[2]: 1 +mesh_shape[3]: -1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.dim: 8192 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' 
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 64 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 
'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32768 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_init.txt new file mode 100644 index 000000000..2f13215e5 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host.txt new file mode 100644 index 000000000..03eb4153c --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host.txt @@ -0,0 +1,312 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 8 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'Llama-3-tokenizer.json' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 8 +evalers['validation'].input.batcher.pad_example_fn: 
'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'Llama-3-tokenizer.json' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 8 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'Llama-3-tokenizer.json' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 
+learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v5litepod-256-4' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[1][1][0]: 1 +mesh_rules[1][1][1]: -1 +mesh_rules[1][1][2]: 1 +mesh_rules[1][1][3]: 128 +mesh_rules[1][1][4]: 1 +mesh_rules[1][1][5]: 1 +mesh_shape[0]: 1 +mesh_shape[1]: 1 +mesh_shape[2]: 1 +mesh_shape[3]: -1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.dim: 8192 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 
+model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' 
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 64 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' 
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host_init.txt new file mode 100644 index 000000000..f0e1c9fec --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 
normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken.txt new file mode 100644 index 000000000..87aac848f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken.txt @@ -0,0 +1,313 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False 
+evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 512 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 
'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v5litepod-256-4' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[1][1][0]: 1 +mesh_rules[1][1][1]: -1 +mesh_rules[1][1][2]: 1 +mesh_rules[1][1][3]: 128 +mesh_rules[1][1][4]: 1 +mesh_rules[1][1][5]: 1 +mesh_shape[0]: 1 +mesh_shape[1]: 1 +mesh_shape[2]: 1 +mesh_shape[3]: -1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 8192 
+model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' 
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 64 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' 
+model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 
+summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_init.txt new file mode 100644 index 000000000..f0e1c9fec --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt index 937f2c94e..d9cc82051 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt @@ -158,7 +158,7 @@ 
model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' model.decoder.lm_head.param_partition_spec[0]: None @@ -172,7 +172,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' @@ -293,7 +293,7 @@ model.decoder.transformer.num_layers: 80 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash_init.txt index f0e1c9fec..8730d5928 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) @@ -7,4 +7,4 @@ decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git 
a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host.txt new file mode 100644 index 000000000..11d87ce4b --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host.txt @@ -0,0 +1,277 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 8 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'Llama-3-tokenizer.json' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 8 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 
'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'Llama-3-tokenizer.json' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 8 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'Llama-3-tokenizer.json' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 
'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v5litepod-256-4' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[1][1][0]: 1 +mesh_rules[1][1][1]: -1 +mesh_rules[1][1][2]: 1 +mesh_rules[1][1][3]: 128 +mesh_rules[1][1][4]: 1 +mesh_rules[1][1][5]: 1 +mesh_shape[0]: 1 +mesh_shape[1]: 1 +mesh_shape[2]: 1 +mesh_shape[3]: -1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.dim: 8192 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' 
+model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj' 
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 64 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 
'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_init.txt new file mode 100644 index 000000000..f0e1c9fec --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git 
a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt new file mode 100644 index 000000000..4c5546f61 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt @@ -0,0 +1,278 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 
'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 512 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' 
+learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v5litepod-256-4' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[1][1][0]: 1 +mesh_rules[1][1][1]: -1 +mesh_rules[1][1][2]: 1 +mesh_rules[1][1][3]: 128 +mesh_rules[1][1][4]: 1 +mesh_rules[1][1][5]: 1 +mesh_shape[0]: 1 +mesh_shape[1]: 1 +mesh_shape[2]: 1 +mesh_shape[3]: -1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 8192 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' 
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' 
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False 
+model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 64 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_init.txt new file mode 100644 index 000000000..f0e1c9fec --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 
+decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt index eeabf4d15..71bc8aecd 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt @@ -158,7 +158,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. 
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' model.decoder.lm_head.param_partition_spec[0]: None @@ -172,7 +172,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' @@ -258,7 +258,7 @@ model.decoder.transformer.num_layers: 80 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3_init.txt index f0e1c9fec..8730d5928 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) @@ -7,4 +7,4 @@ decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt index fd009d652..07bfc434b 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt @@ -201,7 +201,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' model.decoder.lm_head.param_partition_spec[0]: None @@ -215,7 +215,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' @@ -336,7 +336,7 @@ model.decoder.transformer.num_layers: 32 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_init.txt index 02d2b8470..311e12ed7 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) @@ -7,4 +7,4 @@ decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), 
axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host.txt new file mode 100644 index 000000000..ea468dacd --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host.txt @@ -0,0 +1,356 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 16 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 
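Both the checkpointer save_policy and the evaler eval_policy in these configs fire every n=5000 steps between min_step 1 and max_step 3932160. A tiny sketch of such a predicate is below; it is illustrative only (the real implementations are axlearn.common.checkpointer.every_n_steps_and_last_policy, which additionally saves the final step, and axlearn.common.evaler.every_n_steps_policy).

    def every_n_steps(step: int, *, n: int = 5000, min_step: int = 1, max_step: int = 3_932_160) -> bool:
        # Fire on multiples of n that fall inside [min_step, max_step].
        return min_step <= step <= max_step and step % n == 0

    assert every_n_steps(5_000)
    assert not every_n_steps(5_001)
    assert not every_n_steps(0)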
+evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 16 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 16 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 
+learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v4-(1024|2048)' +mesh_rules[0][1][0]: 1 +mesh_rules[0][1][1]: -1 +mesh_rules[0][1][2]: 1 +mesh_rules[0][1][3]: 16 +mesh_rules[0][1][4]: 1 +mesh_rules[0][1][5]: 1 +mesh_rules[1][0]: 'tpu-v5litepod-256' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 +mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' +mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[2][0]: 'tpu-v5litepod-256-2' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[3][0]: 'tpu-v5litepod-256-4' +mesh_rules[3][1].config_modifiers[0].klass: 
'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[3][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' +mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[4][0]: 'tpu-v5p-.*' +mesh_rules[4][1][0]: 1 +mesh_rules[4][1][1]: -1 +mesh_rules[4][1][2]: 1 +mesh_rules[4][1][3]: 8 +mesh_rules[4][1][4]: 1 +mesh_rules[4][1][5]: 1 +mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[5][1][0]: 1 +mesh_rules[5][1][1]: -1 +mesh_rules[5][1][2]: 1 +mesh_rules[5][1][3]: 8 +mesh_rules[5][1][4]: 1 +mesh_rules[5][1][5]: 1 +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 4096 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' 
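The feed-forward stack here uses the activation pair ('nn.silu', 'linear') with two linear1_* projections (see the linear1_0/linear1_1 weights in the *_init.txt files) and a hidden_dim rule of scale 3.5 rounded up to a multiple of 256 (listed just below). A gated-MLP sketch of what those settings imply follows; the gate/value role assignment is an assumption, and this is not the axlearn TransformerFeedForwardLayer.

    import numpy as np

    def silu(x):
        return x / (1.0 + np.exp(-x))

    def round_up(value, multiple=256):
        return ((value + multiple - 1) // multiple) * multiple

    def gated_ffn(x, w1_0, w1_1, w2):
        # Assumed roles: w1_0 feeds the SiLU gate, w1_1 the linear value path.
        return (silu(x @ w1_0) * (x @ w1_1)) @ w2

    print(round_up(int(4096 * 3.5)))  # 14336 -> linear1_* shape (4096, 14336) in the 8B init files
    print(round_up(int(8192 * 3.5)))  # 28672 -> linear1_* shape (8192, 28672) in the 70B init files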
+model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' 
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 32 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' 
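The grouped-query attention settings above (num_heads 32 for this 8B config, num_kv_heads 8, per-head dim 128) account for the fused qkv_proj shapes in the *_init.txt files, assuming FusedGroupedQKVLinear stacks query, key and value heads along a single axis; the arithmetic below is illustrative only.

    def fused_qkv_heads(num_heads: int, num_kv_heads: int = 8) -> int:
        # Assumption: the fused projection stacks Q heads, then K heads, then V heads.
        return num_heads + 2 * num_kv_heads

    PER_HEAD_DIM = 128
    print(fused_qkv_heads(32), PER_HEAD_DIM)  # 48, 128 -> qkv_proj weight (4096, 48, 128) for 8B
    print(fused_qkv_heads(64), PER_HEAD_DIM)  # 80, 128 -> qkv_proj weight (8192, 80, 128) for 70B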
+model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 32 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host_init.txt new file mode 100644 index 000000000..02d2b8470 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git 
a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken.txt new file mode 100644 index 000000000..0906e3fe0 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken.txt @@ -0,0 +1,356 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' 
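With the global batch size of 512 and max_sequence_length of 8192 used by this (non-single-host) 8B tiktoken config, a back-of-the-envelope token budget per step and over max_step works out as follows; it is an upper bound that ignores padding and packing details.

    global_batch_size = 512
    max_sequence_length = 8192
    max_step = 3_932_160

    tokens_per_step = global_batch_size * max_sequence_length
    total_tokens = tokens_per_step * max_step
    print(f"{tokens_per_step:,} tokens/step")                 # 4,194,304
    print(f"~{total_tokens / 1e12:.1f}T tokens at max_step")  # ~16.5T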
+evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 512 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 
'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v4-(1024|2048)' +mesh_rules[0][1][0]: 1 +mesh_rules[0][1][1]: -1 +mesh_rules[0][1][2]: 1 +mesh_rules[0][1][3]: 16 +mesh_rules[0][1][4]: 1 +mesh_rules[0][1][5]: 1 +mesh_rules[1][0]: 'tpu-v5litepod-256' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 +mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' +mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[2][0]: 'tpu-v5litepod-256-2' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' 
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[3][0]: 'tpu-v5litepod-256-4' +mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[3][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' +mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[4][0]: 'tpu-v5p-.*' +mesh_rules[4][1][0]: 1 +mesh_rules[4][1][1]: -1 +mesh_rules[4][1][2]: 1 +mesh_rules[4][1][3]: 8 +mesh_rules[4][1][4]: 1 +mesh_rules[4][1][5]: 1 +mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[5][1][0]: 1 +mesh_rules[5][1][1]: -1 +mesh_rules[5][1][2]: 1 +mesh_rules[5][1][3]: 8 +mesh_rules[5][1][4]: 1 +mesh_rules[5][1][5]: 1 +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 4096 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' 
+model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 
'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 32 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 
'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 32 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken_init.txt new file mode 100644 index 000000000..02d2b8470 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), 
axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt index 5f8e74098..6113f9ad6 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt @@ -201,7 +201,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. 
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' model.decoder.lm_head.param_partition_spec[0]: None @@ -215,7 +215,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' @@ -336,7 +336,7 @@ model.decoder.transformer.num_layers: 32 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_init.txt index 02d2b8470..311e12ed7 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) @@ -7,4 +7,4 @@ decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt index 7d4fd3de4..3fdc09486 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt @@ -201,7 +201,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' model.decoder.lm_head.param_partition_spec[0]: None @@ -215,7 +215,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' @@ -301,7 +301,7 @@ model.decoder.transformer.num_layers: 32 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_init.txt index 02d2b8470..311e12ed7 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) @@ -7,4 +7,4 @@ decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 
decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt new file mode 100644 index 000000000..d79a53429 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt @@ -0,0 +1,321 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 16 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' 
+evalers['validation'].input.batcher.global_batch_size: 16 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 16 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 
'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v4-(1024|2048)' +mesh_rules[0][1][0]: 1 +mesh_rules[0][1][1]: -1 +mesh_rules[0][1][2]: 1 +mesh_rules[0][1][3]: 16 +mesh_rules[0][1][4]: 1 +mesh_rules[0][1][5]: 1 +mesh_rules[1][0]: 'tpu-v5litepod-256' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 +mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' +mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[2][0]: 'tpu-v5litepod-256-2' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[3][0]: 'tpu-v5litepod-256-4' +mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[3][1].config_modifiers[0].mesh_shape[0]: 1 
+mesh_rules[3][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' +mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[4][0]: 'tpu-v5p-.*' +mesh_rules[4][1][0]: 1 +mesh_rules[4][1][1]: -1 +mesh_rules[4][1][2]: 1 +mesh_rules[4][1][3]: 8 +mesh_rules[4][1][4]: 1 +mesh_rules[4][1][5]: 1 +mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[5][1][0]: 1 +mesh_rules[5][1][1]: -1 +mesh_rules[5][1][2]: 1 +mesh_rules[5][1][3]: 8 +mesh_rules[5][1][4]: 1 +mesh_rules[5][1][5]: 1 +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 4096 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' 
+model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj' 
+model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 32 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 
'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 32 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_init.txt new file mode 100644 index 000000000..02d2b8470 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt new file mode 100644 index 000000000..8f86e1a84 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt @@ -0,0 +1,321 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' 
+evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 512 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 
+learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3932160 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v4-(1024|2048)' +mesh_rules[0][1][0]: 1 +mesh_rules[0][1][1]: -1 +mesh_rules[0][1][2]: 1 +mesh_rules[0][1][3]: 16 +mesh_rules[0][1][4]: 1 +mesh_rules[0][1][5]: 1 +mesh_rules[1][0]: 'tpu-v5litepod-256' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 +mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' +mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[2][0]: 'tpu-v5litepod-256-2' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' 
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[3][0]: 'tpu-v5litepod-256-4' +mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[3][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' +mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[4][0]: 'tpu-v5p-.*' +mesh_rules[4][1][0]: 1 +mesh_rules[4][1][1]: -1 +mesh_rules[4][1][2]: 1 +mesh_rules[4][1][3]: 8 +mesh_rules[4][1][4]: 1 +mesh_rules[4][1][5]: 1 +mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[5][1][0]: 1 +mesh_rules[5][1][1]: -1 +mesh_rules[5][1][2]: 1 +mesh_rules[5][1][3]: 8 +mesh_rules[5][1][4]: 1 +mesh_rules[5][1][5]: 1 +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 4096 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 
'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj' 
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 32 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None 
+model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 32 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_init.txt new file mode 100644 index 000000000..02d2b8470 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) 
+decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt index 1d56ef5f4..75ed23089 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt @@ -201,7 +201,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' model.decoder.lm_head.param_partition_spec[0]: None @@ -215,7 +215,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' @@ -301,7 +301,7 @@ model.decoder.transformer.num_layers: 32 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_init.txt index 02d2b8470..311e12ed7 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_init.txt +++ 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) @@ -7,4 +7,4 @@ decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3.txt index 134133cc9..a3e254a79 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3.txt @@ -132,7 +132,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. 
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.logits_partition_spec[0][0]: 'data' model.decoder.logits_partition_spec[0][1]: 'expert' @@ -143,7 +143,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken.txt new file mode 100644 index 000000000..e70342adc --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken.txt @@ -0,0 +1,289 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 3000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3000 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 500 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3000 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 1500 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 32 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 64 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 
'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3000 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 1500 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 32 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 64 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 32 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 64 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 
'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0006 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3000 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.01 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3000 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 8 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 16 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 2.6666666666666665 
+model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' 
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 2 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 4 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 
+model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 4 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken_init.txt new file mode 100644 index 000000000..61615aa53 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash.txt index cfc10714b..fe5c8758a 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash.txt @@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.logits_partition_spec[0][0]: 'data' model.decoder.logits_partition_spec[0][1]: 'expert' @@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken.txt new file mode 100644 index 000000000..0822d244b --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken.txt @@ -0,0 +1,254 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 3000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3000 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 500 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3000 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 1500 +evalers['train'].input.batcher.fn: 
'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 32 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 64 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3000 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 1500 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 32 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 64 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 32 
+input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 64 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0006 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3000 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.01 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3000 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 8 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None 
+model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 16 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 2.6666666666666665 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' 
+model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 2 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 4 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False 
+model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 4 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_init.txt new file mode 100644 index 000000000..61615aa53 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), 
out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3.txt index 966d20e28..20c7634e7 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3.txt @@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. 
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.logits_partition_spec[0][0]: 'data' model.decoder.logits_partition_spec[0][1]: 'expert' @@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.1-70B.json b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.1-70B.json deleted file mode 100644 index 9d03fc1e0..000000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.1-70B.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 128000, - "eos_token_id": 128001, - "hidden_act": "silu", - "hidden_size": 8192, - "initializer_range": 0.02, - "intermediate_size": 28672, - "max_position_embeddings": 131072, - "mlp_bias": false, - "model_type": "llama", - "num_attention_heads": 64, - "num_hidden_layers": 80, - "num_key_value_heads": 8, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": { - "factor": 8.0, - "low_freq_factor": 1.0, - "high_freq_factor": 4.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3" - }, - "rope_theta": 500000.0, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.43.0.dev0", - "use_cache": true, - "vocab_size": 128256 -} diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.1-8B.json b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.1-8B.json deleted file mode 100644 index cccf055d6..000000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.1-8B.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 128000, - "eos_token_id": 128001, - "hidden_act": "silu", - "hidden_size": 4096, - "initializer_range": 0.02, - "intermediate_size": 14336, - "max_position_embeddings": 131072, - "mlp_bias": false, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 8, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": { - "factor": 8.0, - "low_freq_factor": 1.0, - "high_freq_factor": 4.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3" - }, - "rope_theta": 500000.0, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.43.0.dev0", - "use_cache": true, - "vocab_size": 128256 -} diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.2-1B.json 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.2-1B.json deleted file mode 100644 index 83b8b2aeb..000000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.2-1B.json +++ /dev/null @@ -1,35 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 128000, - "eos_token_id": 128001, - "head_dim": 64, - "hidden_act": "silu", - "hidden_size": 2048, - "initializer_range": 0.02, - "intermediate_size": 8192, - "max_position_embeddings": 131072, - "mlp_bias": false, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 16, - "num_key_value_heads": 8, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": { - "factor": 32.0, - "high_freq_factor": 4.0, - "low_freq_factor": 1.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3" - }, - "rope_theta": 500000.0, - "tie_word_embeddings": true, - "torch_dtype": "bfloat16", - "transformers_version": "4.45.0.dev0", - "use_cache": true, - "vocab_size": 128256 -} diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.2-3B.json b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.2-3B.json deleted file mode 100644 index 47d4a5aa6..000000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.2-3B.json +++ /dev/null @@ -1,35 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 128000, - "eos_token_id": 128001, - "head_dim": 128, - "hidden_act": "silu", - "hidden_size": 3072, - "initializer_range": 0.02, - "intermediate_size": 8192, - "max_position_embeddings": 131072, - "mlp_bias": false, - "model_type": "llama", - "num_attention_heads": 24, - "num_hidden_layers": 28, - "num_key_value_heads": 8, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": { - "factor": 32.0, - "high_freq_factor": 4.0, - "low_freq_factor": 1.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3" - }, - "rope_theta": 500000.0, - "tie_word_embeddings": true, - "torch_dtype": "bfloat16", - "transformers_version": "4.45.0.dev0", - "use_cache": true, - "vocab_size": 128256 -} diff --git a/axlearn/experiments/text/gpt/c4_trainer.py b/axlearn/experiments/text/gpt/c4_trainer.py index 8c70422e3..a9d403b78 100644 --- a/axlearn/experiments/text/gpt/c4_trainer.py +++ b/axlearn/experiments/text/gpt/c4_trainer.py @@ -41,22 +41,29 @@ """ -from axlearn.common.config import InstantiableConfig, config_for_function +from axlearn.common.config import InstantiableConfig, config_for_class, config_for_function from axlearn.common.input_lm import lm_text_preprocessor from axlearn.common.utils import get_data_dir from axlearn.experiments.text.common import DataMixtureComponent, vocab from axlearn.experiments.text.gpt import fuji, gspmd from axlearn.experiments.text.gpt.common import mixture_train_input_source, tfds_input +from axlearn.experiments.text.gpt.vocabulary_fuji_v3 import FujiV3Vocabulary from axlearn.experiments.trainer_config_utils import TrainerConfigFn -# Sentencepiece vocabs generated from c4/en:3.0.1. -# See bpe_{32k,128k}.json for the sentencepiece settings. -_SENTENCEPIECE_MODEL_NAME = { - 32 * 1024: "bpe_32k_c4.model", - # TikToken is not yet supported, so we are using sentencepiece for now. - # Our new grain-based inputs can support TikToken in the future. 
- 128256: "bpe_128k_c4.model", -} + +def _vocab_cfg(size: int): + if size == 32 * 1024: + # Sentencepiece vocabs generated from c4/en:3.0.1. + # See bpe_{32k,128k}.json for the sentencepiece settings. + return config_for_function(vocab).set(sentencepiece_model_name="bpe_32k_c4.model") + if size == 128 * 1024: + return config_for_function(vocab).set(sentencepiece_model_name="bpe_128k_c4.model") + if size == 128256: + # TikToken. + return config_for_class(FujiV3Vocabulary).set(filename="Llama-3-tokenizer.json") + raise ValueError(f"No tokenizer is configured for vocab size {size}.") + + _train_data_mixture_components = [ DataMixtureComponent( name="c4/en:3.0.1", @@ -75,9 +82,7 @@ def _eval_input_sources( dataset_name="c4/en:3.0.1", split=split, is_training=False, - vocab_cfg=config_for_function(vocab).set( - sentencepiece_model_name=_SENTENCEPIECE_MODEL_NAME[vocab_size] - ), + vocab_cfg=_vocab_cfg(vocab_size), max_sequence_length=max_sequence_length, ) for name, split in (("train", "train[:8192]"), ("validation", "validation")) @@ -87,9 +92,7 @@ def _train_input_source(*, vocab_size: int, max_sequence_length: int) -> InstantiableConfig: source_cfg = config_for_function(mixture_train_input_source).set( data_mixture_components=_train_data_mixture_components, - vocab_cfg=config_for_function(vocab).set( - sentencepiece_model_name=_SENTENCEPIECE_MODEL_NAME[vocab_size] - ), + vocab_cfg=_vocab_cfg(vocab_size), max_sequence_length=max_sequence_length, preprocessor=config_for_function(lm_text_preprocessor).set(max_padding_fraction=0.5), ) diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index d32120d25..e91fd1c9a 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -241,6 +241,8 @@ def model_config( ffn_structure: str = "prenorm", atten_structure: str = "prenorm", atten_logit_cap: Optional[float] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, ) -> causal_lm.Model.Config: """Returns an LM model config based on the given hyperparams. @@ -271,6 +273,8 @@ def model_config( atten_logit_cap: Cap the absolute values of logits by tanh. Enabled by setting a positive value. remat_offload_dst: Destination of remat checkptoing offloading. + pad_token_id: Int ID of the padding token; padded positions are masked in self-attention. + eos_token_id: Int ID of the end-of-sequence token. Returns: A causal LM config. @@ -301,6 +305,10 @@ def model_config( lm_head=lm_head_cfg, dropout_rate=dropout_rate, ) + if pad_token_id is not None: + decoder_cfg.set(pad_token_id=pad_token_id) + if eos_token_id is not None: + decoder_cfg.set(eos_token_id=eos_token_id) # Model. model_param_init = DefaultInitializer.default_config().set( init_by_param_name={ diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 69f6b1102..4473c08ef 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -67,7 +67,7 @@ class Version(enum.Enum): VOCAB_SIZE = { Version.V1: 32 * 1024, Version.V2: 32 * 1024, - Version.V3: 128256, + Version.V3: 128 * 1024, } @@ -117,8 +117,6 @@ def get_trainer_kwargs( ) -> dict[str, Any]: """Construct default trainer kwargs given a model size.""" tokens_per_batch = 4 * (1024**2) # 4M tokens.
- if model_size not in TOTAL_TOKENS[version]: - return {} max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch max_sequence_length = MAX_SEQUENCE_LENGTH[version] train_batch_size = tokens_per_batch // max_sequence_length @@ -423,6 +421,9 @@ def get_trainer_kwargs( raise NotImplementedError(f"Unknown model size {model_size}.") model_kwargs = trainer_kwargs.pop("model_kwargs") model_kwargs.setdefault("vocab_size", vocab_size) + if version == Version.V3: + model_kwargs["pad_token_id"] = 128004 + model_kwargs["eos_token_id"] = 128001 trainer_kwargs["model_cfg"] = model_config(**model_kwargs) trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config( max_step=trainer_kwargs["max_step"], @@ -445,6 +446,8 @@ def model_config( ffn_dim: Optional[Union[int, config.FunctionConfigBase]] = None, flash_attention: bool = False, stack_cfg: Optional[BaseStackedTransformerLayer.Config] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, ) -> causal_lm.Model.Config: """Returns an LM model config based on the given hyperparams. @@ -500,6 +503,8 @@ def model_config( lm_head_cfg=LmHead.default_config() if not shared_lm_head else None, attention_cfg=flash_attention_config() if flash_attention else atten_cfg, attention_qkv_linear=atten_qkv_linear, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, ) return cfg @@ -515,21 +520,24 @@ def trainer_configs( """ arch = "fuji" config_map = {} - for version, model_size, flash_attention in itertools.product( - Version, MODEL_SIZES, [True, False] + for version, model_size, flash_attention, tiktoken in itertools.product( + Version, MODEL_SIZES, [True, False], [True, False] ): + if model_size not in TOTAL_TOKENS[version]: # This combination does not exist. + continue + if version != Version.V3 and tiktoken: # Only V3 has TikToken option. + continue + suffix = "-flash" if flash_attention else "" vocab_size = VOCAB_SIZE[version] + if tiktoken: + suffix += "-tiktoken" + vocab_size = 128256 config_name = make_config_name( - arch=arch, - model_size=model_size, - version=f"v{version.value}", - suffix="-flash" if flash_attention else "", + arch=arch, model_size=model_size, version=f"v{version.value}", suffix=suffix ) kwargs = get_trainer_kwargs( model_size, vocab_size=vocab_size, version=version, flash_attention=flash_attention ) - if len(kwargs) == 0: # This combination does not exist - continue max_sequence_length = kwargs.pop("max_sequence_length") # pylint: disable-next=unexpected-keyword-arg,missing-kwoa config_map[config_name] = get_trainer_config_fn( @@ -538,7 +546,10 @@ def trainer_configs( max_sequence_length=max_sequence_length, ), evalers=evaler_config_dict( - eval_input_sources(vocab_size=vocab_size, max_sequence_length=max_sequence_length), + eval_input_sources( + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + ), ), **kwargs, ) diff --git a/axlearn/experiments/text/gpt/param_converter_test.py b/axlearn/experiments/text/gpt/param_converter_test.py index ce4de4bf6..9d67d0fce 100644 --- a/axlearn/experiments/text/gpt/param_converter_test.py +++ b/axlearn/experiments/text/gpt/param_converter_test.py @@ -9,7 +9,7 @@ import pytest import torch from absl.testing import absltest, parameterized -from transformers import AutoConfig +from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaForCausalLM from axlearn.common import utils @@ -27,6 +27,65 @@ # Use cpu for the test. 
jax.config.update("jax_platform_name", "cpu") +config_dict_1b = { + "vocab_size": 128256, + "hidden_size": 2048, + "intermediate_size": 8192, + "num_hidden_layers": 16, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "hidden_act": "silu", + "max_position_embeddings": 131072, + "initializer_range": 0.02, + "rms_norm_eps": 1e-5, + "use_cache": True, + "bos_token_id": 128000, + "eos_token_id": 128001, + "pretraining_tp": 1, + "tie_word_embeddings": True, + "rope_theta": 500000.0, + "rope_scaling": { + "factor": 32.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3", + }, + "attention_bias": False, + "attention_dropout": 0.0, + "mlp_bias": False, + "torch_dtype": "bfloat16", + "architectures": ["LlamaForCausalLM"], +} +config_dict_3b = {"hidden_size": 3072, "num_attention_heads": 24, "num_hidden_layers": 28} +config_dict_8b = { + "hidden_size": 4096, + "intermediate_size": 14336, + "num_hidden_layers": 32, + "rope_scaling": { + "factor": 8.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3", + }, + "tie_word_embeddings": False, +} +config_dict_70b = { + "hidden_size": 8192, + "intermediate_size": 28672, + "num_attention_heads": 64, + "num_hidden_layers": 80, + "rope_scaling": { + "factor": 8.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3", + }, + "tie_word_embeddings": False, +} + def compute_fuji_grad(prng_key, fuji: Model, state: NestedTensor, input_batch: NestedTensor): """Compute gradient of fuji model with a pseudo loss.""" @@ -76,25 +135,13 @@ def compute_llama_grad(llama, torch_ids, state): class FujiConvertStateTest(TestCase): @parameterized.parameters( - dict( - fuji_model_name="fuji-1B-v3", - llama_model_name="Llama-3.2-1B", - ), - dict( - fuji_model_name="fuji-3B-v3", - llama_model_name="Llama-3.2-3B", - ), - dict( - fuji_model_name="fuji-8B-v3", - llama_model_name="Llama-3.1-8B", - ), - dict( - fuji_model_name="fuji-70B-v3", - llama_model_name="Llama-3.1-70B", - ), + dict(fuji_model_name="fuji-1B-v3-tiktoken"), + dict(fuji_model_name="fuji-3B-v3-tiktoken"), + dict(fuji_model_name="fuji-8B-v3-tiktoken"), + dict(fuji_model_name="fuji-70B-v3-tiktoken"), ) @pytest.mark.high_cpu - def test_weight_loading(self, fuji_model_name, llama_model_name): + def test_weight_loading(self, fuji_model_name): trainer_config_map = c4_trainer.named_trainer_configs() trainer_config_fn = trainer_config_map[fuji_model_name] trainer_config = trainer_config_fn() @@ -103,17 +150,14 @@ def test_weight_loading(self, fuji_model_name, llama_model_name): fuji: Model = model_config.instantiate(parent=None) prng_key = jax.random.PRNGKey(0) state = fuji.initialize_parameters_recursively(prng_key=prng_key) - config = AutoConfig.from_pretrained( - os.path.join( - dir_path, - "..", - "..", - "testdata", - "axlearn.experiments.text.gpt.param_converter_test", - f"{llama_model_name}.json", - ), - local_files_only=True, - ) + config_dict = config_dict_1b + if fuji_model_name == "fuji-3B-v3-tiktoken": + config_dict.update(config_dict_3b) + elif fuji_model_name == "fuji-8B-v3-tiktoken": + config_dict.update(config_dict_8b) + elif fuji_model_name == "fuji-70B-v3-tiktoken": + config_dict.update(config_dict_70b) + config = LlamaConfig(**config_dict) llama = LlamaForCausalLM._from_config(config) # pylint: disable=W0212 llama = llama.eval() ids = jax.random.randint(jax.random.PRNGKey(123), shape=(2, 2), 
minval=0, maxval=128256) @@ -135,13 +179,13 @@ def test_weight_loading(self, fuji_model_name, llama_model_name): llama_logits = output.logits.numpy() # The difference is caused by the SDPA attention layer. The deeper the larger the error. - if fuji_model_name == "fuji-1B-v3": + if fuji_model_name == "fuji-1B-v3-tiktoken": atol = 2e-3 - elif fuji_model_name == "fuji-3B-v3": + elif fuji_model_name == "fuji-3B-v3-tiktoken": atol = 2e-2 - elif fuji_model_name == "fuji-8B-v3": + elif fuji_model_name == "fuji-8B-v3-tiktoken": atol = 2e-1 - elif fuji_model_name == "fuji-70B-v3": + elif fuji_model_name == "fuji-70B-v3-tiktoken": atol = 2.0 else: atol = 2e-3 diff --git a/axlearn/experiments/text/gpt/vocabulary_fuji_v3.py b/axlearn/experiments/text/gpt/vocabulary_fuji_v3.py new file mode 100644 index 000000000..f60e6f701 --- /dev/null +++ b/axlearn/experiments/text/gpt/vocabulary_fuji_v3.py @@ -0,0 +1,200 @@ +# Copyright © 2024 Apple Inc. + +"""Fuji v3 vocabulary.""" + +import os +import tempfile +from typing import Optional, Protocol, Sequence, Union + +import jax +import numpy as np +import tensorflow.compat.v2 as tf +from tokenizers import Tokenizer + +import axlearn.common.file_system as fs +from axlearn.common.utils import get_data_dir + + +class InnerTokenizer(Protocol): + """Defines a protocol of InnerTokenizer which is used in Vocabulary. + + This is the subset of the sentencepiece_processor.SentencePieceProcessor API that is used in + Vocabulary. + """ + + def encode_as_pieces(self, pieces: str) -> list[str]: + """Encode text input to tokens.""" + pass + + def piece_to_id(self, piece: str) -> int: + """Encode a token to id.""" + pass + + +class Vocabulary(Protocol): + """Defines a protocol of Vocabulary. + + This is a subset of seqio.Vocabulary APIs that are used in text_to_lm_training_input and + text_to_lm_eval_input. + """ + + @property + def pad_id(self) -> int: + pass + + @property + def eos_id(self) -> Optional[int]: + pass + + def encode_tf(self, s: tf.Tensor) -> tf.Tensor: + """Tokenizes string Scalar to an int32 Tensor, without adding EOS.""" + pass + + def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor: + """Detokenizes int32 batched Tensor.""" + pass + + def encode(self, s: str) -> list[int]: + """Tokenizes string to an int sequence, without adding EOS.""" + pass + + def _decode(self, ids: Sequence[int]) -> str: + """Detokenizes int sequence to a string, through all EOS.""" + pass + + def decode(self, ids: Sequence[int]) -> str: + """Detokenizes int32 iterable to a string, up through first EOS.""" + pass + + @property + def tokenizer(self) -> InnerTokenizer: + pass + + +class FujiInnerTokenizer: + """A wrapper for tokenizers.Tokenizer so that it follows the InnerTokenizer Protocol.""" + + def __init__(self, tokenizer): + self._tokenizer = tokenizer + + def encode_as_pieces(self, pieces: str) -> list[str]: + """Encode text input to tokens.""" + return self._tokenizer.encode(pieces, add_special_tokens=False).tokens + + def piece_to_id(self, piece: str) -> int: + """Encode a token to id.""" + return self._tokenizer.token_to_id(piece) + + +class FujiV3Vocabulary: + """A wrapper for tokenizers.Tokenizer so that it follows Vocabulary Protocol. + + Although its name contains fuji, it can be extended to work with any tokenizers.Tokenizer.
+ """ + + def __init__(self, filename: str): + data_dir = get_data_dir() + data_dir = ( + os.path.join(os.path.dirname(__file__), "..", "..", "..", "data") + if data_dir is None or data_dir == "FAKE" + else data_dir + ) + filename = os.path.join(data_dir, "tokenizers", "hf", filename) + if filename.startswith("gs:") or filename.startswith("s3:"): + # Create a different file for each usage. + tmp = tempfile.mkdtemp() + path = os.path.join(tmp, "tokenizer.json") + fs.copy(filename, path) + filename = path + self._tokenizer = Tokenizer.from_file(filename) + self.vocab = self._tokenizer.get_vocab() + self.tokenizer = FujiInnerTokenizer(self._tokenizer) + + @property + def pad_id(self) -> int: + # Some tokenizers do not have a pad_id. + # https://discuss.huggingface.co/t/how-to-set-the-pad-token-for-meta-llama-llama-3-models/103418 + for token in ("<|pad_id|>", "<|finetune_right_pad_id|>"): + if token in self.vocab: + return self.vocab[token] + else: + raise ValueError("Unable to infer pad token.") + + @property + def eos_id(self) -> Optional[int]: + if "<|end_of_text|>" in self.vocab: + return self.vocab["<|end_of_text|>"] + else: + raise NotImplementedError() + + def _encode_tf(self, s: tf.Tensor) -> tf.Tensor: + """Encodes a string to token IDs. + + Args: + s: A tf.Tensor of shape () or (n,) and dtype tf.string. + + Returns: + A tf.Tensor or RaggedTensor of shape (num_tokens,) or (n, None) and dtype tf.int32. + """ + need_unpack = False + if s.ndim == 0: + s = tf.reshape(s, (1,)) + need_unpack = True + + def helper(s): + s = s.numpy() + res = self._tokenizer.encode_batch([item.decode("utf-8") for item in s]) + return tf.ragged.constant([r.ids for r in res]) + + ret = tf.py_function( + helper, inp=[s], Tout=tf.RaggedTensorSpec([None, None], dtype=tf.int32) + ) + if need_unpack: + return ret[0] + else: + return ret + + def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor: + """Detokenizes int32 batched Tensor.""" + need_unpack = False + if ids.ndim == 1: + ids = tf.reshape(ids, (1, -1)) + need_unpack = True + + def helper(ids): + s = self._tokenizer.decode_batch(ids.numpy().tolist(), skip_special_tokens=True) + return tf.convert_to_tensor(s, dtype=tf.string) + + ret = tf.py_function(helper, inp=[ids], Tout=tf.string) + ret.set_shape(tf.TensorShape((None,))) + if need_unpack: + return ret[0] + else: + return ret + + def encode_tf(self, s: tf.Tensor) -> tf.Tensor: + """Tokenizes string Scalar to an int32 Tensor, without adding EOS. + + Args: + s: A tf.Tensor of shape () or (n,) and dtype tf.string. + + Returns: + A tf.Tensor or RaggedTensor of shape (num_tokens,) or (n, None) and dtype tf.int32. 
+ """ + return self._encode_tf(s) + + def encode(self, s: str) -> list[int]: + """Tokenizes string to an int sequence, without adding EOS.""" + return self._tokenizer.encode(s).ids + + def _decode(self, ids: Union[list[int], tuple[int]]) -> str: + """Detokenizes int32 iterable to a string.""" + return self._tokenizer.decode(ids) + + def decode(self, ids: Union[list[int], tuple[int], jax.Array, np.ndarray]) -> str: + """Detokenizes int32 iterable to a string, up through first EOS.""" + if self.eos_id is not None and self.eos_id in ids: + if isinstance(ids, (jax.Array, np.ndarray)): + ids = ids.tolist() # type: ignore + ids = ids[: ids.index(self.eos_id) + 1] + return self._decode(ids) diff --git a/axlearn/experiments/text/gpt/vocabulary_fuji_v3_test.py b/axlearn/experiments/text/gpt/vocabulary_fuji_v3_test.py new file mode 100644 index 000000000..50eb56f4a --- /dev/null +++ b/axlearn/experiments/text/gpt/vocabulary_fuji_v3_test.py @@ -0,0 +1,200 @@ +# Copyright © 2024 Apple Inc. + +"""Tests fuji v3 vocabulary.""" + +import numpy as np +import pytest +import tensorflow.compat.v2 as tf +from absl.testing import parameterized + +from axlearn.common import input_text, input_tf_data +from axlearn.common.config import config_for_class, config_for_function +from axlearn.common.input_lm import ( + PackingMethodType, + lm_text_preprocessor, + text_to_lm_eval_input, + text_to_lm_training_input, +) +from axlearn.common.input_text_test import make_ds_fn +from axlearn.common.test_utils import TestCase +from axlearn.experiments.text.gpt.vocabulary_fuji_v3 import FujiV3Vocabulary + + +@pytest.mark.skip(reason="no tokenizer file.") +class FujiV3VocabularyTest(TestCase): + """Tests FujiV3VocabularyTest.""" + + @property + def vocab_cfg(self): + return config_for_class(FujiV3Vocabulary).set(filename="Llama-3-tokenizer.json") + + def test_encode_tf_and_decode_tf(self): + vocab = self.vocab_cfg.instantiate() + text = tf.constant( + "Lorem ipsum dolor sit amet, consectetur adipiscing elit\n", dtype=tf.string + ) + ids = vocab.encode_tf(text) + recovered = vocab._decode_tf(ids) # pylint: disable=W0212 + + self.assertEqual(text.numpy().decode("utf-8"), recovered.numpy().decode("utf-8")) + + def test_tokenize_example(self): + vocab = self.vocab_cfg.instantiate() + newlines_replaced_with = "" + newlines_replaced_with_id = vocab.encode(newlines_replaced_with)[1:] # remove bos token + + # Test tokenize_example replaces newlines. + tokens = input_text.tokenize_example( + "Hello\n", sp_vocab=vocab, replace_newlines_with=newlines_replaced_with + ).numpy() + self.assertNestedAllClose( + np.array(vocab.encode("Hello") + newlines_replaced_with_id), tokens + ) + + def test_num_bytes(self): + vocab = self.vocab_cfg.instantiate() + newlines_replaced_with = "\n" + pad_id = vocab.pad_id + newline_id = vocab.encode("\n").pop() + newlines_replaced_with_id = vocab.encode(newlines_replaced_with).pop() + + # Test num_bytes computes expected value. 
+ ids = tf.constant( + [vocab.eos_id, newlines_replaced_with_id, newline_id, pad_id, pad_id, pad_id], + dtype=tf.int32, + ) + self.assertEqual( + 3, + input_text.num_bytes( + ids, sp_vocab=vocab, newlines_replaced_with=newlines_replaced_with + ), + ) + + @parameterized.parameters( + dict( + packing_method=PackingMethodType.EOS_DELIM_MASK, + max_padding_fraction=1.0, # Always pad + ), + dict( + packing_method=PackingMethodType.EOS_DELIM_NO_MASK, + max_padding_fraction=1.0, # Always pad + ), + dict( + packing_method=PackingMethodType.EOS_DELIM_MASK, + max_padding_fraction=0.0, # Do not pad + ), + ) + def test_fake_text_lm_training_data( + self, packing_method: PackingMethodType, max_padding_fraction: float + ): + texts = [ + "hello world\n", + "hello moon\n", + ] + + # window_size > len(texts) to repeat the sentence. 18 tokens in total. + window_size = 3 + + # Pad the concatenated sequence to 20 tokens: + # Or, trim the sequence to 15 tokens: + batch_size, max_len = 2, 5 + + # Disable shuffling to make results interpretable. + shuffle_buffer_size = 0 + + # Test text_to_lm_training_input. + cfg = input_tf_data.Input.default_config().set( + name="test_input", + is_training=True, + source=config_for_function(make_ds_fn).set(texts=texts), + processor=config_for_function(text_to_lm_training_input).set( + vocab_cfg=self.vocab_cfg, + max_len=max_len, + replace_newlines_with="", + window_size=window_size, + max_padding_fraction=max_padding_fraction, + shuffle_buffer_size=shuffle_buffer_size, + packing_method=packing_method, + ), + batcher=config_for_function(input_tf_data.batch).set( + global_batch_size=batch_size, + prefetch_buffer_size=2, + pad_example_fn=input_tf_data.default_pad_example_fn, + ), + ) + + # Set TensorFlow seed. + tf.random.set_seed(123) + dataset = cfg.instantiate(parent=None) + for ix, batch in enumerate(dataset): + self.assertIsNotNone(batch) + batch = {k: v.tolist() for k, v in batch.items()} + if ix >= 10: + # Expect to be able to repeat forever. + break + + # Test lm_text_preprocessor. Expect same results. + cfg = input_tf_data.Input.default_config().set( + name="test_input", + is_training=True, + source=config_for_function(make_ds_fn).set( + texts=texts, + ), + processor=config_for_function(lm_text_preprocessor).set( + vocab_cfg=self.vocab_cfg, + max_sequence_length=max_len, + replace_newlines_with="", + window_size=window_size, + max_padding_fraction=max_padding_fraction, + shuffle_buffer_size=shuffle_buffer_size, + packing_method=packing_method, + ), + batcher=config_for_function(input_tf_data.batch).set( + global_batch_size=batch_size, + prefetch_buffer_size=2, + pad_example_fn=input_tf_data.default_pad_example_fn, + ), + ) + + # Reset TensorFlow seed. + tf.random.set_seed(123) + dataset = cfg.instantiate(parent=None) + for ix, batch in enumerate(dataset): + if ix >= 3: + break + batch = {k: v.tolist() for k, v in batch.items()} + + @parameterized.parameters( + ("How long is a piece of string?", "index"), + ("On the 20th of June", "not_index"), + ("Here we stand united", None), + ) + def test_eval_lm_processor_single_example(self, text, index_key): + max_len = 12 + processor = text_to_lm_eval_input( + vocab_cfg=self.vocab_cfg, + max_len=max_len, + replace_newlines_with="\n", + stride=None, + index_key="index", + ) + ds_fn = ( + config_for_function(make_ds_fn) + .set(is_training=False, texts=[text], repeat=1) + .instantiate() + ) + example = next(iter(processor(ds_fn()))) + for key in ["input_ids", "target_labels"]: + # Shape is as expected. 
+ self.assertEqual((max_len,), example[key].numpy().shape) + self.assertTrue("target_num_bytes" in example) + # The index should be passed through only when `index_key` is set to "index". + self.assertEqual(index_key == "index", index_key in example) + + input_ids, target_labels = example["input_ids"].numpy(), example["target_labels"].numpy() + self.assertEqual(128000, input_ids[1]) # BOS + non_padded_length = (target_labels == 128004).argmax() + self.assertNotEqual(128001, target_labels[0]) # No EOS at start. + self.assertEqual(128001, target_labels[non_padded_length - 1]) # EOS. + # The inputs should be offset from the labels by one position. + self.assertNestedAllClose(target_labels[:-1], input_ids[1:])
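
A minimal usage sketch of FujiV3Vocabulary for reference (illustrative only, not part of the diff); it assumes the Llama-3-tokenizer.json file is available under data/tokenizers/hf/ and that the optional tokenizers package is installed, mirroring how c4_trainer.py and the unit test construct the vocab config:

import tensorflow.compat.v2 as tf

from axlearn.common.config import config_for_class
from axlearn.experiments.text.gpt.vocabulary_fuji_v3 import FujiV3Vocabulary

# Build the vocab the same way c4_trainer._vocab_cfg does for vocab_size=128256.
vocab_cfg = config_for_class(FujiV3Vocabulary).set(filename="Llama-3-tokenizer.json")
vocab = vocab_cfg.instantiate()

# Special token ids are inferred from the tokenizer file; they should match the decoder
# config above (pad_token_id=128004, eos_token_id=128001).
print(vocab.pad_id, vocab.eos_id)

# Plain-Python round trip. The Llama 3 tokenizer prepends BOS on encode; decode truncates
# at the first EOS (if present) and skips special tokens.
ids = vocab.encode("hello world\n")
print(ids, vocab.decode(ids))

# TF path used by the training/eval input pipelines: string scalar -> int32 ids -> string.
ids_tf = vocab.encode_tf(tf.constant("hello world\n", dtype=tf.string))
print(vocab._decode_tf(ids_tf).numpy().decode("utf-8"))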