diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py
index 37baf3d8b..fedcd51fb 100644
--- a/axlearn/common/attention.py
+++ b/axlearn/common/attention.py
@@ -3956,15 +3956,22 @@ def policy(prim, *_, **params):
     return policy


-SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)"
-FEED_FORWARD_SAVE_PATTERN = ".*linear[12]_.*"
+# Regex patterns for matching remat names.
+class RematRegexSavePatterns(enum.Enum):
+    QKV_PROJ = r".*[kqv]_proj"
+    O_PROJ = r".*o_proj"
+    CONTEXT = r".*context"
+    LINEAR1_X = r".*linear1_[01]"
+    LINEAR2_X = r".*linear2_[01]"
+    SELF_ATTENTION = re.compile("|".join([QKV_PROJ, O_PROJ, CONTEXT])).pattern
+    FEED_FORWARD = re.compile("|".join([LINEAR1_X, LINEAR2_X])).pattern


 def build_remat_spec(
     stack_cfg: Union[
         BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config"  # type: ignore
     ],
-    save_pattern: _SavePattern = SELF_ATTENTION_SAVE_PATTERN,
+    save_pattern: _SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value,
     offload_pattern: _SavePattern = None,
     offload_dst: str = "pinned_host",
 ) -> Optional[RematSpec]:
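The change above splits the monolithic save patterns into per-tensor regexes so that platform-specific remat policies (such as the Trainium mesh rules in the testdata below) can recombine them with "|".join(...). Assuming the full-match semantics implied by the leading .* in every pattern, the combined SELF_ATTENTION pattern should match exactly the same names as the old '.*([qkvo]_proj|context)'. A minimal sketch checking that equivalence with the stdlib re module; the checkpoint-name paths are hypothetical, for illustration only:

import re

OLD = r".*([qkvo]_proj|context)"
NEW = r".*[kqv]_proj|.*o_proj|.*context"  # RematRegexSavePatterns.SELF_ATTENTION.value

# Hypothetical remat checkpoint names, not taken from an actual trace.
names = [
    "decoder/transformer/layer/self_attention/attention/q_proj",
    "decoder/transformer/layer/self_attention/attention/o_proj",
    "decoder/transformer/layer/self_attention/attention/context",
    "decoder/transformer/layer/feed_forward/linear1_0",
]
for name in names:
    # Both patterns save the q/k/v/o projections and the attention context,
    # and neither saves the feed-forward outputs.
    assert bool(re.fullmatch(OLD, name)) == bool(re.fullmatch(NEW, name))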
diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py
index 1e188ecc0..268649c10 100644
--- a/axlearn/common/attention_test.py
+++ b/axlearn/common/attention_test.py
@@ -41,7 +41,6 @@

 from axlearn.common import attention, attention_bias, test_utils, utils
 from axlearn.common.attention import (
-    FEED_FORWARD_SAVE_PATTERN,
     BaseStackedTransformerLayer,
     BaseTransformerLayer,
     BottleNeckAdapterTransformerLayer,
@@ -58,6 +57,7 @@
     PipelinedTransformerLayer,
     QKVLinear,
     QLinear,
+    RematRegexSavePatterns,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
     StackedTransformerLayer,
@@ -3420,7 +3420,7 @@ def f(x, layer_params):
             jax.remat(
                 f,
                 policy=_save_and_offload_only_these_names_regex(
-                    names_which_can_be_saved=FEED_FORWARD_SAVE_PATTERN,
+                    names_which_can_be_saved=RematRegexSavePatterns.FEED_FORWARD.value,
                     names_which_can_be_offloaded=None,
                     offload_src="device",
                     offload_dst="pinned_host",
@@ -3875,6 +3875,70 @@ def f(x, layer_params):
             5,
         )

+    def test_build_remat_spec_neuron(self):
+        model_dim, num_heads = 6, 2
+        cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim)
+        cfg.self_attention.attention.set(num_heads=num_heads, causal=True)
+        cfg.feed_forward.hidden_dim = model_dim * 4
+        cfg.vlog = 5
+
+        layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None)
+        layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))
+
+        batch_size, tgt_len = 2, 5
+        rng = np.random.default_rng(seed=123)
+        target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32)
+
+        def f(x, layer_params):
+            forward_outputs, _ = F(
+                layer,
+                inputs=dict(
+                    data=x,
+                ),
+                state=layer_params,
+                is_training=True,
+                prng_key=jax.random.PRNGKey(0),
+            )
+            return forward_outputs
+
+        # Ignore type errors.
+        spec: Any = build_remat_spec(mock.MagicMock())
+
+        policy = (
+            config_for_function(_save_and_offload_only_these_names_regex)
+            .set(
+                names_which_can_be_saved="|".join(
+                    [
+                        RematRegexSavePatterns.QKV_PROJ.value,
+                        RematRegexSavePatterns.LINEAR1_X.value,
+                    ]
+                ),
+                names_which_can_be_offloaded=None,
+                offload_src=None,
+                offload_dst=None,
+            )
+            .instantiate()
+        )
+
+        _, default_policy_backward = jax.linearize(
+            jax.remat(f, policy=policy, prevent_cse=spec.prevent_cse),
+            jnp.asarray(target),
+            layer_params,
+        )
+        _, full_remat_backward = jax.linearize(
+            jax.remat(f),
+            jnp.asarray(target),
+            layer_params,
+        )
+
+        # Eliminating the remat of qkv_proj (3 dots) and linear1_0 (1 dot)
+        # saves 4 dots in total. This assumes FlashAttention is not enabled.
+        self.assertEqual(
+            str(full_remat_backward).count(" dot_general")
+            - str(default_policy_backward).count(" dot_general"),
+            4,
+        )
+

 class TestStackModel(BaseLayer):
     """A dummy transformer stack."""
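The new test measures the policy's effect by counting dot_general ops in the linearized backward functions: under full remat every projection is recomputed, so saving the q/k/v projections (3 dots) and linear1_0 (1 dot) should shrink the count by exactly 4. Below is a self-contained toy version of that counting trick. It is not the AXLearn test, just a sketch using only public JAX APIs:

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def f(x, w1, w2):
    # Name the intermediate so a remat policy can refer to it.
    h = checkpoint_name(x @ w1, name="linear1_0")
    return jnp.sum(jax.nn.relu(h) @ w2)

x, w1, w2 = jnp.ones((4, 8)), jnp.ones((8, 8)), jnp.ones((8, 8))

save_policy = jax.checkpoint_policies.save_only_these_names("linear1_0")
_, saved_bwd = jax.linearize(jax.remat(f, policy=save_policy), x, w1, w2)
_, full_bwd = jax.linearize(jax.remat(f), x, w1, w2)

# Saving "linear1_0" removes its recomputation from the backward pass, so the
# difference is the number of dots that no longer need to be replayed.
print(str(full_bwd).count(" dot_general") - str(saved_bwd).count(" dot_general"))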
diff --git a/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-l-rnnt.txt b/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-l-rnnt.txt
index db50dbf1e..e8c5f4c26 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-l-rnnt.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-l-rnnt.txt
@@ -485,7 +485,7 @@ model.encoder.context.context.num_layers: 17
 model.encoder.context.context.remat_spec['prevent_cse']: False
 model.encoder.context.context.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.encoder.context.context.remat_spec['policy'].names_which_can_be_offloaded: None
-model.encoder.context.context.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.encoder.context.context.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.encoder.context.context.remat_spec['policy'].offload_dst: 'pinned_host'
 model.encoder.context.context.remat_spec['policy'].offload_src: 'device'
 model.encoder.context.context.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-test-ctc.txt b/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-test-ctc.txt
index b4b4a909c..e05044a98 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-test-ctc.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-test-ctc.txt
@@ -188,7 +188,7 @@ model.encoder.context.context.num_layers: 1
 model.encoder.context.context.remat_spec['prevent_cse']: False
 model.encoder.context.context.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.encoder.context.context.remat_spec['policy'].names_which_can_be_offloaded: None
-model.encoder.context.context.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.encoder.context.context.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.encoder.context.context.remat_spec['policy'].offload_dst: 'pinned_host'
 model.encoder.context.context.remat_spec['policy'].offload_src: 'device'
 model.encoder.context.context.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt
index f8e50909a..109b27352 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt
@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt
index 2a831b28c..2db97729b 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt
@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt
index 0a470ccea..611b197c0 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt
@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt
index c4c6eed38..b9358475b 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt
@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt
index d06cfb3c7..8193253eb 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt
@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt
index 8d5dc4e92..0de0bb45f 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt
@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt
index 8d7e8f710..8af3901eb 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt
@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt
index 53ef5d052..6e8b647e0 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt
@@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
index ade5f1af2..382362b84 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
@@ -135,6 +135,22 @@ mesh_rules[1][1][2]: 1
 mesh_rules[1][1][3]: 128
 mesh_rules[1][1][4]: 1
 mesh_rules[1][1][5]: 1
+mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
@@ -214,7 +230,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
index a986f1d08..c8e0ecf8e 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
@@ -135,6 +135,22 @@ mesh_rules[1][1][2]: 1
 mesh_rules[1][1][3]: 128
 mesh_rules[1][1][4]: 1
 mesh_rules[1][1][5]: 1
+mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
@@ -214,7 +230,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
index 03fc3428a..0dca36c79 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
@@ -135,6 +135,22 @@ mesh_rules[1][1][2]: 1
 mesh_rules[1][1][3]: 128
 mesh_rules[1][1][4]: 1
 mesh_rules[1][1][5]: 1
+mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
@@ -214,7 +230,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
index 1ecf7529f..0a497518d 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
@@ -135,6 +135,22 @@ mesh_rules[1][1][2]: 1
 mesh_rules[1][1][3]: 128
 mesh_rules[1][1][4]: 1
 mesh_rules[1][1][5]: 1
+mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
@@ -214,7 +230,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
index 76193c0db..df6a44563 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
@@ -135,6 +135,22 @@ mesh_rules[1][1][2]: 1
 mesh_rules[1][1][3]: 128
 mesh_rules[1][1][4]: 1
 mesh_rules[1][1][5]: 1
+mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
@@ -214,7 +230,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
index 45bdb8e66..da2640d58 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
@@ -135,6 +135,22 @@ mesh_rules[1][1][2]: 1
 mesh_rules[1][1][3]: 128
 mesh_rules[1][1][4]: 1
 mesh_rules[1][1][5]: 1
+mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
 mesh_shape[2]: 1
@@ -214,7 +230,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
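The neuron mesh rule added to the fuji-70B configs above keeps the q/k/v projections and the first feed-forward projection on device (offload_src and offload_dst are None, i.e. no host offload) and recomputes everything else, including o_proj, the attention context, and linear2. A quick sketch of what that pattern covers, again with hypothetical checkpoint names and assuming full-match semantics:

import re

NEURON_SAVE = r".*[kqv]_proj|.*linear1_[01]"  # save pattern from the mesh rules above

# Hypothetical remat checkpoint names, for illustration only.
for name in [
    "decoder/transformer/layer/self_attention/attention/q_proj",
    "decoder/transformer/layer/self_attention/attention/o_proj",
    "decoder/transformer/layer/self_attention/attention/context",
    "decoder/transformer/layer/feed_forward/linear1_0",
    "decoder/transformer/layer/feed_forward/linear2_0",
]:
    # q/k/v_proj and linear1_* are saved; o_proj, context, and linear2_* are
    # recomputed, trading recompute for lower activation memory on trn2.
    print(name, "->", "saved" if re.fullmatch(NEURON_SAVE, name) else "recomputed")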
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt
index 98cd9261c..a8da85c51 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt
@@ -254,7 +254,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt
index 0a62cc2b1..fbb40bb06 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt
@@ -254,7 +254,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt
index a9e1f38ed..dfab6421e 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt
@@ -254,7 +254,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt
index 87736a6f5..9d8523972 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt
@@ -254,7 +254,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt
index e01051cac..eb2eb01ac 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt
@@ -254,7 +254,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt
index 964f23e23..1ce7de634 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt
@@ -254,7 +254,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt
index 17f97ab30..04ed987b4 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt
@@ -254,7 +254,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt
index 438da62a1..590fc7bac 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt
@@ -254,7 +254,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt
index a5b50a240..e1a687ebf 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt
@@ -257,7 +257,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt
index da5826693..4089710af 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt
@@ -257,7 +257,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt
index 811b565e5..87fa85ca4 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt
@@ -257,7 +257,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt
index b71e46c9d..798229109 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt
@@ -257,7 +257,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v1.txt
index c361cd018..e55f53182 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v1.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v1.txt
@@ -185,7 +185,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v2.txt
index c361cd018..e55f53182 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v2.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v2.txt
@@ -185,7 +185,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3.txt
index c759ca4c7..7d16a8c47 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3.txt
@@ -185,7 +185,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1-flash.txt
index b65d70147..8ff237ed0 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1-flash.txt
@@ -190,7 +190,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1.txt
index 385e22dc7..398c2a251 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v1.txt
@@ -190,7 +190,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2-flash.txt
index b65d70147..8ff237ed0 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2-flash.txt
@@ -190,7 +190,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2.txt
index 385e22dc7..398c2a251 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v2.txt
@@ -190,7 +190,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash.txt
index 30dbdba03..a9c28c2cd 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash.txt
@@ -190,7 +190,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3.txt
index 221701f53..daaf4c38d 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3.txt
@@ -190,7 +190,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8-stream.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8-stream.txt
index 460fdfa61..42754fa6d 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8-stream.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8-stream.txt
@@ -197,7 +197,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8.txt
index d300cd1d8..b62c2ce4a 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8.txt
@@ -197,7 +197,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.deterministic_trainer/honeycrisp-3B-pajama-15t-49k.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.deterministic_trainer/honeycrisp-3B-pajama-15t-49k.txt
index ab0734137..a363b0bc9 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.deterministic_trainer/honeycrisp-3B-pajama-15t-49k.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.deterministic_trainer/honeycrisp-3B-pajama-15t-49k.txt
@@ -220,7 +220,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.deterministic_trainer/honeycrisp-85M-pajama-15t-49k.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.deterministic_trainer/honeycrisp-85M-pajama-15t-49k.txt
index 6cea4cfbc..d61af12ef 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.deterministic_trainer/honeycrisp-85M-pajama-15t-49k.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.deterministic_trainer/honeycrisp-85M-pajama-15t-49k.txt
@@ -222,7 +222,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
 model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
 model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
 model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.deterministic_trainer/honeycrisp-test-pajama-15t-49k.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.deterministic_trainer/honeycrisp-test-pajama-15t-49k.txt
index 8cfd8c936..970a762e6 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.deterministic_trainer/honeycrisp-test-pajama-15t-49k.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.deterministic_trainer/honeycrisp-test-pajama-15t-49k.txt
@@ -187,7 +187,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
 model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context' model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-1B-2k-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-1B-2k-sp-rp.txt index f596cb3c7..c5981d104 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-1B-2k-sp-rp.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-1B-2k-sp-rp.txt @@ -221,7 +221,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye model.decoder.transformer.layer.remat_spec['prevent_cse']: False model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context' model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-85M-2k-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-85M-2k-sp-rp.txt index b3644816f..423a0518e 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-85M-2k-sp-rp.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_sigmoid_trainer/gala-sigmoid-85M-2k-sp-rp.txt @@ -221,7 +221,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye model.decoder.transformer.layer.remat_spec['prevent_cse']: False model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context' model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-1B-flash-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-1B-flash-sp-rp.txt index 940c8055b..79190d6e1 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-1B-flash-sp-rp.txt +++ 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-1B-flash-sp-rp.txt @@ -220,7 +220,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye model.decoder.transformer.layer.remat_spec['prevent_cse']: False model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context' model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-1B-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-1B-sp-rp.txt index a2e3c3bc0..e8a2c77a3 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-1B-sp-rp.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-1B-sp-rp.txt @@ -220,7 +220,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye model.decoder.transformer.layer.remat_spec['prevent_cse']: False model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context' model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-302M-flash-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-302M-flash-sp-rp.txt index 96cd67f46..dfdba2309 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-302M-flash-sp-rp.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-302M-flash-sp-rp.txt @@ -220,7 +220,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye model.decoder.transformer.layer.remat_spec['prevent_cse']: False model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context' model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True diff --git 
a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-302M-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-302M-sp-rp.txt index 03bb99416..7615d95b5 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-302M-sp-rp.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-302M-sp-rp.txt @@ -220,7 +220,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye model.decoder.transformer.layer.remat_spec['prevent_cse']: False model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context' model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-flash-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-flash-sp-rp.txt index 2447ee154..700762154 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-flash-sp-rp.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-flash-sp-rp.txt @@ -241,7 +241,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye model.decoder.transformer.layer.remat_spec['prevent_cse']: False model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context' model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-sp-rp.txt index c52968df1..498d46e0b 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-sp-rp.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-7B-sp-rp.txt @@ -241,7 +241,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye model.decoder.transformer.layer.remat_spec['prevent_cse']: False model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context' 
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-85M-flash-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-85M-flash-sp-rp.txt index 27f971e4f..3221f0189 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-85M-flash-sp-rp.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-85M-flash-sp-rp.txt @@ -220,7 +220,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye model.decoder.transformer.layer.remat_spec['prevent_cse']: False model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context' model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-85M-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-85M-sp-rp.txt index e728fabb1..4435b414c 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-85M-sp-rp.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-85M-sp-rp.txt @@ -220,7 +220,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye model.decoder.transformer.layer.remat_spec['prevent_cse']: False model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context' model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-test-flash-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-test-flash-sp-rp.txt index 2cf60996d..650abe356 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-test-flash-sp-rp.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-test-flash-sp-rp.txt @@ -213,7 +213,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye model.decoder.transformer.layer.remat_spec['prevent_cse']: False model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' 
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context' model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-test-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-test-sp-rp.txt index ee9ce6584..5eaa00f5c 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-test-sp-rp.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/gala-test-sp-rp.txt @@ -213,7 +213,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye model.decoder.transformer.layer.remat_spec['prevent_cse']: False model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context' model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-3B-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-3B-sp-rp.txt index c9c7dba57..c84e7e518 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-3B-sp-rp.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-3B-sp-rp.txt @@ -257,7 +257,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye model.decoder.transformer.layer.remat_spec['prevent_cse']: False model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context' model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-85M-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-85M-sp-rp.txt index 25ddc20e1..60c73d288 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-85M-sp-rp.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-85M-sp-rp.txt @@ -259,7 +259,7 @@ 
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-test-sp-rp.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-test-sp-rp.txt
index 1e6a164c9..674950348 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-test-sp-rp.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.pajama_trainer/honeycrisp-test-sp-rp.txt
@@ -224,7 +224,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 973fb9234..c01789a39 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -34,6 +34,7 @@
 BaseQKVLinear,
 MultiheadAttention,
 RepeatedTransformerLayer,
+ StackedTransformerLayer,
 TransformerLayer,
 build_remat_spec,
 set_double_shard_weights_config,
@@ -190,20 +191,12 @@ def update_model_remat_config(
 ):
 """Recomputes and sets the remat_spec based on provided layer_cfg.
- Only applied if the stack_cfg is a RepeatedTransformerLayer.
-
 Args:
 stack_cfg: The transformer stack config.
 layer_cfg: The transformer layer config.
 offload_dst: Destination of remat checkpointing offloading.
- Raises:
- NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
""" - if stack_cfg.klass is not RepeatedTransformerLayer: - raise NotImplementedError( - f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}" - ) remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg)) layer_cfg.set(remat_spec=remat_spec) @@ -277,7 +270,7 @@ def model_config( layer_cfg.self_attention.attention.input_linear = attention_qkv_linear layer_cfg.self_attention.structure = atten_structure layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap - if stack_cfg.klass is RepeatedTransformerLayer: + if issubclass(stack_cfg.klass, (RepeatedTransformerLayer, StackedTransformerLayer)): update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg) # Stack. transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 69f6b1102..960db18ef 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -24,8 +24,10 @@ FusedQKVLinear, GroupedQueryAttention, MultiheadAttention, + RematRegexSavePatterns, RepeatedTransformerLayer, RoFormerQKVLinear, + _save_and_offload_only_these_names_regex, ) from axlearn.common.base_layer import RematSpec from axlearn.common.config import config_for_function @@ -85,7 +87,6 @@ class Version(enum.Enum): Version.V3: 5e5, } - # Mapping from Fuji versions to total number of tokens used in training. TOTAL_TOKENS = { Version.V1: { @@ -417,6 +418,36 @@ def get_trainer_kwargs( "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)", mesh_shape_from_axes(data=-1, fsdp=128), ), + ( + "neuron-(trn2|trn2n).48xlarge-64", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4) + ), + RematSpecModifier.default_config().set( + remat_policies={ + "model.decoder.transformer.layer": RematSpec( + prevent_cse=True, + policy=config_for_function( + _save_and_offload_only_these_names_regex + ).set( + names_which_can_be_saved="|".join( + [ + RematRegexSavePatterns.QKV_PROJ.value, + RematRegexSavePatterns.LINEAR1_X.value, + ] + ), + names_which_can_be_offloaded=None, + offload_src=None, + offload_dst=None, + ), + ), + } + ), + ], + ), + ), ), ) else: